5
from torch.testing._internal.common_utils import (
6
TestCase, TEST_WITH_TORCHDYNAMO, run_tests, skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, parametrize,
7
instantiate_parametrized_tests, TemporaryFileName)
12
from torch.testing._internal.jit_utils import RUN_CUDA
13
from torch._guards import tracing, TracingContext
14
from torch._subclasses.fake_tensor import (
16
extract_tensor_metadata,
20
DynamicOutputShapeException,
21
UnsupportedOperatorException,
22
unset_fake_temporarily,
24
from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, StatelessSymbolicContext
25
from torch.testing._internal.custom_op_db import custom_op_db
26
from torch.testing._internal.common_device_type import ops
27
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
28
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
29
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
30
from torch._dynamo.testing import rand_strided
31
from torch._C._functorch import is_batchedtensor, _add_batch_dim, get_unwrapped
32
from torch.testing import FileCheck
36
import torch._prims as prims
41
import torch._functorch.config
42
import torch.testing._internal.optests as optests
43
from unittest.mock import patch
45
from torch import distributed as dist
46
from torch.utils._mode_utils import no_dispatch
47
from torch.utils._python_dispatch import TorchDispatchMode
48
import torch.utils._pytree as pytree
52
torch._dynamo.config.fake_tensor_cache_enabled = True
53
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
55
class FakeTensorTest(TestCase):
56
def checkType(self, t, device_str, size):
57
self.assertTrue(isinstance(t, FakeTensor))
58
self.assertEqual(t.device.type, device_str)
59
self.assertEqual(list(t.size()), size)
62
@unittest.skipIf(not RUN_CUDA, "requires cuda")
63
def test_cuda_initialized(self):
65
with FakeTensorMode():
66
p = torch.randn(4, 2, requires_grad=True, device='cuda')
67
x = torch.randn(8, 4, device='cuda')
68
y = torch.mm(x, p).square().sum()
72
x = torch.empty(2, 2, device="cpu")
73
y = torch.empty(4, 2, 2, device="cpu")
74
with FakeTensorMode() as mode:
75
x = mode.from_tensor(x)
76
y = mode.from_tensor(y)
78
self.assertEqual(z.shape, (4, 2, 2))
79
self.assertEqual(z.device, torch.device("cpu"))
80
self.assertTrue(isinstance(z, FakeTensor))
82
def test_basic_forced_memo_only(self):
83
x = torch.empty(2, 2, device="cpu")
84
y = torch.empty(4, 2, 2, device="cpu")
85
with FakeTensorMode() as mode:
86
x_fake = mode.from_tensor(x)
87
x2 = mode.from_tensor(x, memoized_only=True)
88
self.assertTrue(x2 is not None)
89
y = mode.from_tensor(y, memoized_only=True)
90
self.assertIs(y, None)
92
def test_custom_op_fallback(self):
93
from torch.library import Library, impl
95
test_lib = Library("my_test_op", "DEF")
96
test_lib.define('foo(Tensor self) -> Tensor')
98
@impl(test_lib, 'foo', 'CPU')
102
x = torch.empty(2, 2, device="cpu")
103
with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
104
with FakeTensorMode(allow_fallback_kernels=True) as mode:
105
x = mode.from_tensor(x)
106
torch.ops.my_test_op.foo(x)
108
def test_parameter_instantiation(self):
109
with FakeTensorMode():
111
y = torch.nn.parameter.Parameter(x)
112
self.assertTrue(isinstance(y, torch.nn.Parameter))
114
@unittest.skipIf(not dist.is_available(), "requires distributed")
115
def test_fsdp_flat_param(self):
116
from torch.distributed.fsdp._flat_param import FlatParameter
117
with FakeTensorMode() as m:
118
data = torch.randn(2, 2)
119
param = FlatParameter(data, requires_grad=True)
120
self.assertIsInstance(param, FlatParameter)
121
self.assertIsInstance(param, torch.nn.Parameter)
122
self.assertIsInstance(param, FakeTensor)
124
def test_non_parameter_grad(self):
125
mode = FakeTensorMode()
126
t = torch.rand([4], requires_grad=True)
127
fake_t = mode.from_tensor(t)
128
self.assertEqual(fake_t.requires_grad, t.requires_grad)
130
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
131
@unittest.skipIf(not RUN_CUDA, "requires cuda")
132
def test_index_cuda_with_cpu(self):
133
with FakeTensorMode():
134
x = torch.rand([2048], device='cuda')
135
out = x[torch.zeros([36], dtype=torch.int64)]
136
self.checkType(out, "cuda", [36])
138
@unittest.skipIf(not RUN_CUDA, "requires cuda")
139
def test_shape_take_not_device(self):
140
with FakeTensorMode():
141
x = torch.empty(1, device="cpu")
142
y = torch.empty(8, 8, device="cuda")
143
out = x.resize_as_(y)
144
self.assertEqual(out.shape, (8, 8))
145
self.assertEqual(out.device.type, "cpu")
146
self.assertTrue(isinstance(out, FakeTensor))
149
with FakeTensorMode():
150
x = torch.empty(2, 2, device="cpu")
151
self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
152
x = torch.empty(2, 2, device="meta")
153
self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")
155
@unittest.skipIf(not RUN_CUDA, "requires cuda")
156
def test_zero_dim(self):
157
with FakeTensorMode() as mode:
159
y = torch.rand([4, 4], device="cuda")
161
self.assertEqual(out.shape, (4, 4))
162
self.assertEqual(out.device, y.device)
163
self.assertTrue(isinstance(out, FakeTensor))
165
def test_nan_to_num(self):
166
with FakeTensorMode():
167
for dtype in [torch.float16, torch.float32]:
168
x = torch.rand([4], dtype=dtype)
169
y = torch.nan_to_num(x, nan=None)
170
z = torch.nan_to_num(x, 0.0)
171
self.assertEqual(dtype, y.dtype)
172
self.assertEqual(dtype, z.dtype)
174
@unittest.skipIf(not RUN_CUDA, "requires cuda")
175
def test_throw(self):
177
with FakeTensorMode() as mode:
178
x_conv = mode.from_tensor(x)
179
y = torch.rand([4, 4], device="cuda")
180
z = torch.rand([4, 4], device="cpu")
181
self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))
183
@unittest.skipIf(not RUN_CUDA, "requires cuda")
184
def test_type_as(self):
185
with FakeTensorMode():
186
x = torch.rand([16, 1], device="cpu")
187
y = torch.rand([4, 4], device="cuda")
189
self.assertEqual(out.device.type, "cuda")
190
self.assertTrue(isinstance(out, FakeTensor))
192
@unittest.skipIf(not RUN_CUDA, "requires cuda")
193
def test_setitem(self):
194
for device in ["cpu", "cuda"]:
195
with FakeTensorMode():
196
x = torch.rand([16, 1], device=device)
199
@unittest.skipIf(not RUN_CUDA, "requires cuda")
200
def test_device_inplace_copy(self):
201
with FakeTensorMode():
202
x = torch.rand([8, 8], device="cpu")
203
y = torch.rand([8, 8], device="cuda")
204
assert x.copy_(y).device.type == "cpu"
205
assert y.copy_(x).device.type == "cuda"
207
def test_fake_dispatch_keys(self):
208
with FakeTensorMode():
210
f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
211
f.run(torch._C._dispatch_key_set(x))
213
with torch.inference_mode():
216
FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
217
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))
219
def test_batch_tensor(self):
220
x = torch.rand((3, 4, 5))
221
b = _add_batch_dim(x, 0, 0)
222
mode = FakeTensorMode()
223
fake_b = mode.from_tensor(b)
224
prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)
226
b1 = _add_batch_dim(x, 1, 1)
227
b2 = _add_batch_dim(b1, 0, 2)
228
fake_b2 = mode.from_tensor(b2)
229
prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
230
self.assertTrue(is_batchedtensor(fake_b2))
231
fake_b1 = get_unwrapped(fake_b2)
232
self.assertTrue(is_batchedtensor(fake_b1))
233
fake_tensor = get_unwrapped(fake_b1)
234
self.assertIsInstance(fake_tensor, FakeTensor)
236
def test_constructor(self):
237
with FakeTensorMode():
238
x = torch.rand([4, 4], device="cpu")
240
self.assertTrue(isinstance(x, FakeTensor))
241
self.assertTrue(x.device.type == "cpu")
244
with FakeTensorMode():
245
y = torch.rand([4], device="cpu")
248
self.assertTrue(isinstance(out, FakeTensor))
252
with torch._subclasses.CrossRefFakeMode():
253
y = torch.full((4, 4), 1)
255
def check_function_with_fake(self, fn):
257
with torch._subclasses.FakeTensorMode():
260
for a, b in zip(pytree.tree_leaves(out), pytree.tree_leaves(out_fake)):
261
if not isinstance(a, torch.Tensor):
262
self.assertTrue(not isinstance(b, torch.Tensor))
265
prims.utils.compare_tensor_meta(a, b, check_strides=True)
267
@unittest.skipIf(not RUN_CUDA, "requires cuda")
268
def test_non_kwarg_device(self):
269
with FakeTensorMode():
270
x = torch.rand([16, 1], device="cpu")
271
y = x.to(torch.device("cpu"))
273
z = x.to(torch.device("cuda"))
274
self.assertEqual(z.device.type, "cuda")
276
def test_non_overlapping_stride_zero(self):
278
x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
281
self.check_function_with_fake(foo)
283
def test_fake_mode_error(self):
284
x = torch.rand([4, 4])
286
with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
287
with FakeTensorMode():
290
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
291
def test_fake_grad_copy(self):
292
x = torch.rand([4, 4], requires_grad=True)
293
x.grad = torch.rand([4, 4])
294
mode = FakeTensorMode()
295
fake_x = mode.from_tensor(x)
296
prims.utils.compare_tensor_meta(fake_x, x)
297
prims.utils.compare_tensor_meta(fake_x.grad, x.grad)
299
self.assertTrue(isinstance(fake_x.grad, FakeTensor))
301
@unittest.skipIf(not RUN_CUDA, "requires cuda")
302
def test_index_put_error(self):
303
mode = FakeTensorMode()
304
for context in [contextlib.nullcontext, lambda: mode]:
306
y = torch.randn(2, 2, 3)
307
x = torch.randn(2, 2, 3).to('cuda')
308
with self.assertRaises(RuntimeError):
311
with self.assertRaises(RuntimeError):
312
torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y)
315
torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))
316
torch.ops.aten.index_put_(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))
320
@unittest.skipIf(not RUN_CUDA, "requires cuda")
321
def test_like_constructor(self):
322
with FakeTensorMode():
323
x = torch.rand([4, 4])
324
y = torch.ones_like(x)
325
self.assertTrue(isinstance(y, FakeTensor))
326
self.assertEqual(y.device.type, "cpu")
327
z = torch.ones_like(x, device="cuda")
328
self.assertTrue(isinstance(z, FakeTensor))
329
self.assertEqual(z.device.type, "cuda")
331
def test_binary_op_type_promotion(self):
332
with FakeTensorMode():
333
x = torch.empty([2, 2], dtype=torch.float)
334
y = torch.empty([2, 2], dtype=torch.int64)
336
self.assertEqual(out.dtype, torch.float)
337
self.assertEqual(out.device.type, "cpu")
339
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
340
def test_from_numpy(self):
341
with FakeTensorMode():
342
x = torch.tensor(np.zeros([4, 4]))
343
self.checkType(x, "cpu", [4, 4])
345
def test_randperm(self):
346
x = torch.randperm(10)
347
y = torch.randperm(5, device="cpu")
348
with FakeTensorMode():
349
x1 = torch.randperm(10)
350
prims.utils.compare_tensor_meta(x, x1)
351
y1 = torch.randperm(5, device="cpu")
352
prims.utils.compare_tensor_meta(y, y1)
354
def test_print_in_fake_mode(self):
357
with FakeTensorMode():
359
assert "FakeTensor" not in out
361
@unittest.skipIf(not RUN_CUDA, "requires cuda")
362
def test_upsample_bilinear_small_channels(self):
364
mode = FakeTensorMode()
365
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
367
arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='cuda')
368
unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
369
out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False))
371
self.assertTrue(out[1].is_contiguous())
372
self.checkMetaProps(out[0], out[1])
374
@unittest.skipIf(not RUN_CUDA, "requires cuda")
375
def test_cpu_fallback(self):
376
with FakeTensorMode(allow_fallback_kernels=False):
377
filters = torch.randn(8, 4, 3, 3).cuda()
378
inputs = torch.randn(1, 4, 5, 5).cuda()
379
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
380
self.assertEqual(out.device.type, "cuda")
381
self.assertEqual(list(out.size()), [1, 8, 5, 5])
383
with FakeTensorMode(allow_fallback_kernels=True):
385
filters = torch.randn(8, 20, 3, 3).cuda()
386
inputs = torch.randn(1, 7, 10, 5).cuda()
387
with self.assertRaises(RuntimeError):
388
torch.nn.functional.conv2d(inputs, filters, padding=1)
390
with FakeTensorMode(allow_fallback_kernels=True):
391
filters = torch.randn(8, 4, 3, 3).cuda()
392
inputs = torch.randn(1, 4, 5, 5).cuda()
394
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
395
self.assertEqual(out.device.type, "cuda")
396
self.assertEqual(list(out.size()), [1, 8, 5, 5])
398
@unittest.skipIf(not RUN_CUDA, "requires cuda")
399
def test_out_multi_device(self):
400
with FakeTensorMode():
402
y = torch.rand([4], device="cuda")
404
with self.assertRaisesRegex(Exception, "found two different devices"):
407
with self.assertRaisesRegex(Exception, "found two different devices"):
411
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
412
@unittest.skipIf(not RUN_CUDA, "requires cuda")
413
def test_normalize_device(self):
414
with FakeTensorMode():
415
x = torch.empty(1, device="cuda")
416
y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
418
self.checkType(out, "cuda", [1])
420
def test_recursive_invocation(self):
421
mode = FakeTensorMode()
424
mode.in_kernel_invocation = True
426
self.assertTrue(mode.in_kernel_invocation)
428
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
430
@parametrize("allow_fallback_kernels", [False, True],
431
lambda a: 'with_fallback' if a else 'without_fallback')
432
@unittest.skipIf(not RUN_CUDA, "requires cuda")
433
def test_cudnn_rnn(self, allow_fallback_kernels):
474
return torch.ops.aten._cudnn_rnn(
493
mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
494
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
497
torch.randn([92, 8, 2048]).cuda(),
498
torch.randn([8192, 2048]).cuda(),
499
torch.randn([8192, 2048]).cuda(),
500
torch.randn([8192]).cuda(),
501
torch.randn([8192]).cuda(),
502
torch.randn([8192, 2048]).cuda(),
503
torch.randn([8192, 2048]).cuda(),
504
torch.randn([8192]).cuda(),
505
torch.randn([8192]).cuda(),
506
torch.randn([8192, 4096]).cuda(),
507
torch.randn([8192, 2048]).cuda(),
508
torch.randn([8192]).cuda(),
509
torch.randn([8192]).cuda(),
510
torch.randn([8192, 4096]).cuda(),
511
torch.randn([8192, 2048]).cuda(),
512
torch.randn([8192]).cuda(),
513
torch.randn([8192]).cuda(),
514
torch.randn([167837696]).cuda(),
515
torch.randn([4, 8, 2048]).cuda(),
516
torch.randn([4, 8, 2048]).cuda(),
519
inps2[len(inps2) - 1] = None
521
for inps in [inps1, inps2]:
523
self.assertIs(out[4], inps[-3])
526
self.assertTrue(isinstance(ten, FakeTensor))
527
self.assertEqual(ten.device.type, 'cuda')
529
@unittest.skipIf(not RUN_CUDA, "requires cuda")
530
def test_cuda_lstm(self):
532
with torch.backends.cudnn.flags(enabled=False):
533
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
534
with fake_tensor_mode:
542
D = 2 if bidir else 1
543
H_out = proj_size if proj_size > 0 else hidden_size
545
lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
546
num_layers=num_layers, proj_size=proj_size, batch_first=False,
547
bias=True, bidirectional=bidir, device='cuda')
549
h_0 = torch.randn((num_layers * D, N, H_out), device='cuda')
550
c_0 = torch.randn((num_layers * D, N, hidden_size), device='cuda')
551
inp = torch.randn((L, N, H_in), device='cuda')
552
(output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
553
output.sum().backward()
555
self.assertEqual(output.shape, (L, N, D * H_out))
556
self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
557
self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))
559
def test_data_dependent_operator(self):
560
with FakeTensorMode(allow_fallback_kernels=False):
561
x = torch.rand([10, 10])
563
self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))
565
def test_tolist(self):
566
shape_env = ShapeEnv()
567
with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env):
571
def test_same_shape_env_preserved(self):
572
shape_env = ShapeEnv()
573
mode1 = FakeTensorMode(shape_env=shape_env)
574
t1 = mode1.from_tensor(
576
symbolic_context=StatelessSymbolicContext(
577
dynamic_sizes=[DimDynamic.DYNAMIC],
578
constraint_sizes=[None]
581
mode2 = FakeTensorMode(shape_env=shape_env)
582
t2 = mode2.from_tensor(t1)
584
self.assertIsNot(t2, t1)
585
self.assertIs(t1.fake_mode, mode1)
586
self.assertIs(t2.fake_mode, mode2)
587
self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
588
self.assertEqual(str(t2.size(0)), str(t1.size(0)))
590
def test_jagged_fake_to_fake_preserved(self):
591
from torch.nested._internal.nested_tensor import jagged_from_list
595
a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
596
b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
597
c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
599
jt, _ = jagged_from_list([a, b, c], offsets)
600
shape_env = ShapeEnv()
601
mode1 = FakeTensorMode(shape_env=shape_env)
602
t1 = mode1.from_tensor(jt)
603
mode2 = FakeTensorMode(shape_env=shape_env)
604
t2 = mode2.from_tensor(t1)
607
self.assertTrue(free_symbols(t1.size()))
608
self.assertIsNot(t2, t1)
609
self.assertIs(t1.offsets().fake_mode, mode1)
610
self.assertIs(t2.offsets().fake_mode, mode2)
611
self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
612
self.assertEqual(str(t2.size(1)), str(t1.size(1)))
614
def checkMetaProps(self, t1, t2):
615
prims.utils.compare_tensor_meta(t1, t2, check_strides=True)
618
def test_deepcopy(self):
619
with FakeTensorMode() as mode:
621
mod = torch.nn.BatchNorm2d(10)
622
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
623
mod_copied = copy.deepcopy(mod)
625
def check_copy(mod, mod_copied):
626
for name, param in itertools.chain(mod.named_parameters(), mod.named_buffers()):
627
param_copied = getattr(mod_copied, name)
628
self.checkMetaProps(param, param_copied)
629
self.assertTrue(isinstance(param_copied, FakeTensor))
630
self.assertEqual(isinstance(param, torch.nn.Parameter), isinstance(param_copied, torch.nn.Parameter))
631
self.assertEqual(param.requires_grad, param_copied.requires_grad)
633
check_copy(mod, mod_copied)
635
class ModuleNew(torch.nn.Module):
638
self.a = torch.rand([10, 2])
643
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
644
mod_copied = copy.deepcopy(mod)
646
self.assertIs(mod_copied.a, mod_copied.b)
647
self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)
649
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
650
@unittest.skipIf(not RUN_CUDA, "requires cuda")
652
with FakeTensorMode():
653
a = torch.rand([16, 1])
654
self.checkType(a.new(10, 10), "cpu", [10, 10])
655
self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
656
b = torch.rand([4, 4], device='cuda')
657
self.checkType(b.new(device='cuda'), "cuda", [0])
658
self.checkType(a.new(torch.rand([1])), "cpu", [1])
660
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
661
def test_scalar_inputs(self):
662
with FakeTensorMode():
663
self.checkType(torch.div(3, 2), "cpu", [])
664
ten = torch.zeros(2, dtype=torch.int32) * 2.0
665
self.assertEqual(ten.dtype, torch.float)
666
self.checkType(ten, "cpu", [2])
668
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
669
def test_allow_meta(self):
671
with FakeTensorMode():
672
x = torch.rand([4], device="meta")
675
self.checkType(run_meta(), "meta", [4])
677
with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
678
self.assertRaises(Exception, run_meta)
680
def test_embedding_bag_meta(self):
684
embedding = torch.nn.EmbeddingBag(10, 3, mode='sum', device='meta')
685
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
686
offsets = torch.tensor([0, 4], dtype=torch.long)
687
return embedding(input, offsets)
690
with FakeTensorMode():
693
for r, f in zip(real_out, fake_out):
694
self.assertEqual(r.size(), f.size())
695
self.assertEqual(r.device, f.device)
697
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
698
def test_mixed_real_and_fake_inputs(self):
699
class _TestPattern(torch.nn.Module):
702
self.conv = torch.nn.Conv2d(1, 1, 1)
703
self.bn = torch.nn.BatchNorm2d(1)
705
def forward(self, input):
706
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
707
scale_factor = self.bn.weight / running_std
708
weight_shape = [1] * len(self.conv.weight.shape)
710
bias_shape = [1] * len(self.conv.weight.shape)
712
scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
713
zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
714
conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
715
conv_orig = conv / scale_factor.reshape(bias_shape)
716
conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
717
conv = self.bn(conv_orig)
720
example_inputs = (torch.randn(1, 1, 3, 3),)
722
with FakeTensorMode(allow_non_fake_inputs=True):
723
out = mod(torch.randn(1, 1, 3, 3))
724
self.checkType(out, "cpu", (1, 1, 3, 3))
726
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
727
@unittest.skipIf(not RUN_CUDA, "requires cuda")
728
def test_aten_copy_multi_device(self):
729
with FakeTensorMode():
730
x1 = torch.rand(4, device="cpu")
731
x2 = torch.rand(4, device="cuda")
732
copy1 = torch.ops.aten.copy.default(x1, x2)
733
copy2 = torch.ops.aten.copy.default(x2, x1)
734
out = torch.empty(4, device="cpu")
735
torch.ops.aten.copy.out(x1, x2, out=out)
736
self.checkType(copy1, "cpu", (4,))
737
self.checkType(copy2, "cuda", (4,))
738
self.checkType(out, "cpu", (4,))
740
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
741
@unittest.skipIf(not RUN_CUDA, "requires cuda")
742
def test_aten_index_multi_device(self):
743
with FakeTensorMode():
744
x1 = torch.rand(4, 4, device="cpu")
745
x2 = torch.rand(4, 4, device="cuda")
746
i1 = torch.tensor([0, 1], device="cuda")
747
i2 = torch.tensor([0, 1], device="cpu")
748
r1 = torch.ops.aten.index(x1, i1)
749
r2 = torch.ops.aten.index(x2, i2)
751
y1 = torch.rand(4, device="cpu")
752
y2 = torch.rand(4, device="cuda")
753
j1 = torch.tensor([2], device="cuda")
754
j2 = torch.tensor([2], device="cpu")
755
r3 = torch.ops.aten.index_put.default(x1, j1, y1)
756
r4 = torch.ops.aten.index_put.default(x2, j2, y2)
757
self.checkType(r1, "cpu", ())
758
self.checkType(r2, "cuda", ())
759
self.checkType(r3, "cpu", (4, 4))
760
self.checkType(r4, "cuda", (4, 4))
762
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
763
@unittest.skipIf(not RUN_CUDA, "requires cuda")
764
def test_aten_slice_scatter_multi_device(self):
765
with FakeTensorMode():
766
x1 = torch.rand(4, 4, device="cpu")
767
y1 = torch.rand(2, 4, device="cuda")
768
x2 = torch.rand(4, 4, device="cuda")
769
y2 = torch.rand(2, 4, device="cpu")
770
out = torch.empty(4, 4, device="cpu")
771
r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
772
r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
773
r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
774
self.checkType(r1, "cpu", (4, 4))
775
self.checkType(r2, "cuda", (4, 4))
776
self.checkType(r3, "cpu", (4, 4))
777
self.checkType(out, "cpu", (4, 4))
779
def test__adaptive_avg_pool2d_backward(self):
780
with FakeTensorMode():
781
grad_out = torch.rand(2, 3, 4, 4)
782
inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
783
grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
784
self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)
787
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
789
def test_export_numpy(self):
790
class MyNumpyModel(torch.nn.Module):
791
def forward(self, input):
792
input = input.numpy()
793
return input + np.random.randn(*input.shape)
795
with FakeTensorMode():
796
ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
797
self.assertTrue(isinstance(ep, torch.export.ExportedProgram))
800
class FakeTensorConstHandling(TestCase):
801
def assertConst(self, *args):
803
self.assertTrue(arg.constant is not None)
805
def assertNotConst(self, *args):
807
self.assertTrue(arg.constant is None)
809
def test_simple(self):
810
with FakeTensorMode():
812
self.assertEqual(x.item(), 4.)
814
def test_inplace_add(self):
815
with FakeTensorMode():
818
self.assertEqual(x.item(), 5.)
819
self.assertEqual(y.item(), 5.)
820
self.assertConst(x, y)
822
def test_shared_storages(self):
823
with FakeTensorMode():
824
x = torch.tensor([4.])
827
self.assertEqual(x.storage()._cdata, y.storage()._cdata)
828
self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)
830
def test_constant_invalidation(self):
831
with FakeTensorMode():
832
x = torch.tensor([1.])
836
self.assertNotConst(x)
838
def test_inplace_view_invalidation(self):
839
with FakeTensorMode():
840
x = torch.tensor([1])
843
self.assertEqual(x.size(0), 2)
844
self.assertNotConst(x)
846
def test_fake_tensor_in_intlist_repro(self):
849
max_size = torch.tensor([800, 1216], dtype=torch.int64)
850
batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
851
return tensors[0].new_full(batch_shape, 0.0)
853
with self.assertRaises(torch._subclasses.fake_tensor.DataDependentOutputException):
854
with torch._subclasses.fake_tensor.FakeTensorMode():
855
a = torch.randn(3, 800, 1199)
856
b = torch.randn(3, 800, 800)
860
def test_fake_tensor_batch_norm_cpu(self):
861
with torch._subclasses.CrossRefFakeMode():
862
m = torch.nn.Sequential(
863
torch.nn.BatchNorm2d(10),
867
out = m(torch.randn([2, 10, 8, 8]))
869
def test_shared_storage_invalidation(self):
870
with FakeTensorMode():
871
x = torch.tensor([1.])
873
self.assertConst(x, y)
874
y.add_(torch.rand([1]))
875
self.assertNotConst(x, y)
877
def test_aliased_const_write(self):
878
with FakeTensorMode():
879
x = torch.tensor([1])
881
self.assertNotConst(y)
883
self.assertNotConst(x)
885
def test_constant_propagate_through_functions(self):
886
with FakeTensorMode():
887
y = torch.div(4, 4, rounding_mode='trunc')
890
def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
891
return maybe_contained_type.isSubtypeOf(type) or any(
892
contains_type(e, maybe_contained_type) for e in type.containedTypes()
896
class FakeTensorOpInfoTest(TestCase):
897
@ops(custom_op_db, dtypes=OpDTypes.any_one)
898
def test_fake(self, device, dtype, op):
899
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
900
for sample_input in sample_inputs_itr:
901
args = (sample_input.input,) + sample_input.args
902
kwargs = sample_input.kwargs
903
optests.fake_check(op, args, kwargs)
906
class FakeTensorConverterTest(TestCase):
907
def test_memoized_conversion_to_meta(self):
908
x = torch.rand(2, 2, 2)
909
mode = FakeTensorMode()
910
self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))
912
def test_memoized_conversion_from_meta(self):
913
x = torch.rand(2, 2).to(device="meta")
914
mode = FakeTensorMode()
915
converter = mode.fake_tensor_converter
916
self.assertTrue(converter.from_meta_and_device(mode, x, "cpu") is converter.from_meta_and_device(mode, x, "cpu"))
918
def test_separate_tensor_storages_view(self):
919
x = torch.rand(2, 2, 2)
921
mode = FakeTensorMode()
922
converter = mode.fake_tensor_converter
923
x_conv = converter(mode, x)
924
y_conv = converter(mode, y)
925
self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))
927
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
928
def test_separate_tensor_storages_non_view(self):
929
x = torch.rand(2, 2, 2)
932
mode = FakeTensorMode()
933
converter = mode.fake_tensor_converter
934
x_conv = converter(mode, x)
935
y_conv = converter(mode, y)
936
stor_id = torch._C._storage_id(x_conv)
937
self.assertEqual(stor_id, torch._C._storage_id(y_conv))
939
self.assertEqual(len(converter.tensor_memo), 1)
940
converter.meta_converter.check_for_expired_weak_storages()
941
self.assertEqual(len(converter.meta_converter.storage_memo), 1)
943
self.assertEqual(len(converter.tensor_memo), 0)
944
converter.meta_converter.check_for_expired_weak_storages()
945
self.assertEqual(len(converter.meta_converter.storage_memo), 0)
948
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
949
def test_dead_weak_ref(self):
950
x = torch.rand(2, 2, 2)
952
mode = FakeTensorMode()
953
converter = FakeTensorConverter()
954
x_conv = converter(mode, x)
955
x_conv_storage = torch._C._storage_id(x_conv)
957
self.assertFalse(x in converter.tensor_memo)
958
y_conv = converter(mode, y)
959
self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))
961
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
962
def test_dead_key(self):
963
x = torch.rand(2, 2, 2)
964
mode = FakeTensorMode()
965
converter = FakeTensorConverter()
966
x_conv = converter(mode, x)
967
self.assertEqual(len(converter.tensor_memo), 1)
968
x_conv2 = converter(mode, x)
969
assert x_conv2 is x_conv
971
self.assertEqual(len(converter.tensor_memo), 0)
973
def test_no_active_mode(self):
974
with FakeTensorMode() as mode:
975
x = torch.empty(2, 2, device="cpu")
976
y = torch.empty(2, 2, device="cpu")
979
self.assertEqual(mode, out.fake_mode)
980
self.assertTrue(isinstance(out, FakeTensor))
981
self.assertEqual(out.device.type, "cpu")
983
def test_multiple_modes(self):
986
with FakeTensorMode() as m:
987
with FakeTensorMode() as m2:
988
t_fake = m.from_tensor(t)
989
t2_fake = m2.from_tensor(t2)
991
with self.assertRaisesRegex(Exception, "Mixing fake modes"):
994
def test_separate_mode_error(self):
995
with FakeTensorMode():
996
x = torch.empty(2, 2, device="cpu")
997
with FakeTensorMode():
998
y = torch.empty(2, 2, device="cpu")
999
self.assertRaises(Exception, lambda: x, y)
1001
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1002
def test_no_ref_cycle(self):
1004
mode = FakeTensorMode()
1005
y = mode.from_tensor(x)
1006
self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
1007
mode_weak = weakref.ref(mode)
1008
y_weak = weakref.ref(mode)
1011
assert mode_weak() is None
1012
assert y_weak() is None
1015
class FakeTensorOperatorInvariants(TestCase):
1017
def get_aten_op(schema):
1018
namespace, name = schema.name.split("::")
1019
overload = schema.overload_name if schema.overload_name else "default"
1020
assert namespace == "aten"
1021
return getattr(getattr(torch.ops.aten, name), overload)
1024
def get_all_aten_schemas():
1025
for schema in torch._C._jit_get_all_schemas():
1026
namespace = schema.name.split("::")[0]
1027
if namespace != "aten":
1031
def test_non_kwarg_only_device(self):
1032
for schema in self.get_all_aten_schemas():
1033
ten_type = torch._C.TensorType.get()
1035
contains_type(arg.type, ten_type)
1036
for arg in itertools.chain(schema.arguments, schema.returns)
1040
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
1041
has_non_kwarg_device = any(
1042
not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
1043
for arg in schema.arguments
1045
if has_non_kwarg_device:
1047
self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
1050
def test_tensor_constructors_all_have_kwarg_device(self):
1051
for schema in self.get_all_aten_schemas():
1052
op = self.get_aten_op(schema)
1053
if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
1056
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
1057
has_kwarg_device = any(
1058
arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
1059
for arg in schema.arguments
1063
has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
1066
@unittest.expectedFailure
1067
def test_sparse_new(self):
1068
with FakeTensorMode():
1069
indices = torch.randn(1, 1, dtype=torch.int64)
1070
values = torch.randn(1)
1072
sparse = torch.randn(1).to_sparse()
1075
sparse2 = sparse.new(indices, values, extra)
1077
def test_tensor_new(self):
1078
with FakeTensorMode():
1079
x = torch.Tensor([1, 2, 3])
1080
self.assertIsInstance(x, FakeTensor)
1082
def test_like_ops(self):
1083
for schema in self.get_all_aten_schemas():
1084
if "_like" == schema.name[-5:]:
1085
op = self.get_aten_op(schema)
1086
self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)
1088
def test_str_storage(self):
1090
with FakeTensorMode() as m:
1091
y = m.from_tensor(x)
1092
self.assertExpectedInline(str(x.storage()), '''\
1096
[torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]''')
1097
self.assertExpectedInline(str(y.storage()), '''\
1099
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]''')
1101
self.assertExpectedInline(str(y.storage()), '''\
1103
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]''')
1107
def test_embedding_bag_private(self):
1110
torch.ones(6, dtype=torch.int64),
1111
torch.arange(2, dtype=torch.int64),
1116
ref_out = torch.ops.aten._embedding_bag(*args)
1117
with FakeTensorMode() as m:
1118
meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
1119
meta_out = torch.ops.aten._embedding_bag(*meta_args)
1121
self.assertEqual(len(ref_out), len(meta_out))
1122
for ref_o, meta_o in zip(ref_out, meta_out):
1123
self.assertEqual(ref_o.size(), meta_o.size())
1125
def test_cross_entropy_loss(self):
1126
inp = torch.randn(3, 5)
1127
target = torch.randint(5, (3,), dtype=torch.long)
1128
weight = torch.rand(5)
1129
fn = torch.nn.functional.cross_entropy
1130
for w in (weight, None):
1131
args = (inp, target, w)
1133
with FakeTensorMode() as m:
1134
meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
1135
meta_out = torch.nn.functional.cross_entropy(*meta_args, label_smoothing=0.5)
1137
self.assertEqual(ref.size(), meta_out.size())
1140
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1141
def test_flash_attention(self):
1142
class Repro(torch.nn.Module):
1146
def forward(self, arg1, arg2, arg3):
1147
torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3, scale=0.17677669529663687)
1151
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1152
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1153
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1156
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1157
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1158
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1161
for args_list in args_new:
1162
args = [rand_strided(bsz, num_heads, seq_len, head_dim) for
1163
(bsz, num_heads, seq_len, head_dim) in args_list]
1165
with torch._subclasses.CrossRefFakeMode():
1167
except RuntimeError as e:
1170
self.assertTrue("output[0]" not in str(e))
1171
self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
1174
@unittest.skipIf(not RUN_CUDA, "requires cuda")
1175
def test_conv_c1_backward(self):
1176
class Repro(torch.nn.Module):
1180
def forward(self, arg1, arg2, arg3):
1181
torch.ops.aten.convolution_backward.default(
1192
[True, True, False],
1196
((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"),
1197
((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"),
1198
((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"),
1200
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]
1202
with torch._subclasses.CrossRefFakeMode():
1205
def test_no_dispatch_with_like_function(self):
1206
class CountingMode(TorchDispatchMode):
1210
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1212
return func(*args, **kwargs)
1214
with FakeTensorMode():
1216
with CountingMode() as mode:
1220
self.assertEqual(mode.count, 0)
1223
class FakeTensorPropTest(TestCase):
1224
def test_fake_tensor_prop_on_nn_module(self):
1225
class ToyNnModuleWithParameters(torch.nn.Module):
1228
self.layer1 = torch.nn.Linear(4, 3)
1229
self.layer2 = torch.nn.Linear(3, 2)
1231
def forward(self, value):
1232
value = self.layer1(value)
1233
value = torch.relu(value)
1234
value = self.layer2(value)
1237
model = ToyNnModuleWithParameters()
1238
value = torch.randn(5, 4)
1240
graph_model = torch.fx.symbolic_trace(model, (value,))
1245
with FakeTensorMode() as fake_tensor_mode:
1247
def to_fake_tensor(x):
1248
if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
1249
return fake_tensor_mode.from_tensor(x)
1252
fake_parameters_and_buffers = {
1253
k: to_fake_tensor(v)
1254
for k, v in itertools.chain(
1255
graph_model.named_parameters(), graph_model.named_buffers()
1258
with torch.nn.utils.stateless._reparametrize_module(
1259
graph_model, fake_parameters_and_buffers
1265
result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
1266
self.assertTrue(isinstance(result, FakeTensor))
1267
self.assertEqual(result.shape, (5, 2))
1274
FakeTensorProp(graph_model).propagate(value)
1275
except AssertionError:
1278
self.assertTrue(failed)
1281
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
1282
class OptionalArgumentInBetween(torch.nn.Module):
1285
self.layer1 = torch.nn.Linear(4, 3)
1286
self.layer2 = torch.nn.Linear(3, 2)
1288
def forward(self, value, another_value=None, another_optional_value=None):
1292
if another_value is None:
1293
another_value = torch.rand_like(value)
1294
if another_optional_value is None:
1295
another_optional_value = torch.rand_like(value)
1296
value = value + another_value + another_optional_value
1297
return value * value
1299
fake_mode = FakeTensorMode(allow_non_fake_inputs=True, allow_fallback_kernels=False)
1301
model = OptionalArgumentInBetween()
1302
value = torch.randn(5, 4)
1303
another_optional_value = torch.randn(5, 4)
1304
graph_model = torch.fx.symbolic_trace(model, (value, None, another_optional_value))
1305
FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)
1308
def test_torch_load_with_fake_mode(self):
1310
class TheModelClass(torch.nn.Module):
1313
self.fc1 = torch.nn.Linear(5, 10)
1315
def forward(self, x):
1318
with TemporaryFileName() as state_dict_file:
1320
model = TheModelClass()
1321
torch.save(model.state_dict(), state_dict_file)
1323
fake_mode = FakeTensorMode()
1325
torch.load(state_dict_file)
1326
torch.load(state_dict_file, map_location="cpu")
1329
class FakeTensorSerialization(TestCase):
1330
def test_serialization(self):
1331
x = torch.tensor([0], device="cpu")
1332
with FakeTensorMode():
1333
y = pickle.loads(pickle.dumps(x))
1334
self.assertEqual(type(y), FakeTensor)
1335
self.assertEqual(y.device.type, "meta")
1337
with unset_fake_temporarily():
1338
y = pickle.loads(pickle.dumps(x))
1339
self.assertEqual(x.device, y.device)
1341
def test_serialization_with_tracing(self):
1342
x = torch.tensor([0], device="cpu")
1343
with tracing(TracingContext(FakeTensorMode())):
1344
y = pickle.loads(pickle.dumps(x))
1345
self.assertEqual(x.device, y.device)
1348
class FakeTensorDispatchCache(TestCase):
1349
def test_shape_env_settings(self):
1351
Validation that any boolean settings in ShapeEnv are present in the
1352
_ShapeEnvSettings. We hope to ensure that any new settings that might
1353
affect FakeTensor dispatch are included in the cache key calculation.
1354
If this test fails, consider updating _ShapeEnvSettings or change this
1355
test to omit checking for the new field.
1357
init_sig = inspect.signature(ShapeEnv._init)
1359
name for name, param in init_sig.parameters.items()
1360
if type(param.default) is bool
1363
settings = [f.name for f in dataclasses.fields(_ShapeEnvSettings)]
1365
self.assertTrue(arg in settings)
1367
def _test_cache_key(self, fm, x, y, z):
1369
Helper for all test_cache_key_* tests below. Assert that the
1370
cache keys for inputs x and y are the same, but z is different.
1372
func = aten.add.Tensor
1373
key_x = fm._cache_key(func, [x], {})
1374
key_y = fm._cache_key(func, [y], {})
1375
key_z = fm._cache_key(func, [z], {})
1377
self.assertEqual(key_x, key_y)
1378
self.assertNotEqual(key_x, key_z)
1380
def test_cache_key_dtype(self):
1381
with FakeTensorMode() as fm:
1382
x = torch.randn(4, 3, dtype=torch.float16)
1383
y = torch.randn(4, 3, dtype=torch.float16)
1384
z = x.to(dtype=torch.float32)
1385
self._test_cache_key(fm, x, y, z)
1387
def test_cache_key_shape(self):
1388
with FakeTensorMode() as fm:
1389
x = torch.randn(4, 3)
1390
y = torch.randn(4, 3)
1391
z = torch.randn(4, 2)
1392
self._test_cache_key(fm, x, y, z)
1394
def test_cache_key_stride(self):
1395
with FakeTensorMode() as fm:
1396
x = torch.randn(4, 2)
1397
y = torch.randn(4, 2)
1398
z = x.as_strided((4, 2), (1, 2))
1399
self._test_cache_key(fm, x, y, z)
1401
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_cache_key_device(self):
    """Moving one tensor to CUDA must give it a different cache key."""
    with FakeTensorMode() as mode:
        host_a = torch.randn(4, 3)
        host_b = torch.randn(4, 3)
        on_cuda = host_a.to(device="cuda")
        self._test_cache_key(mode, host_a, host_b, on_cuda)
1409
def test_cache_key_memory_format(self):
1410
with FakeTensorMode() as fm:
1411
x = torch.randn(1, 2, 3, 4)
1412
y = torch.randn(1, 2, 3, 4)
1413
z = x.to(memory_format=torch.channels_last)
1414
self._test_cache_key(fm, x, y, z)
1416
def test_cache_key_storage_offset(self):
1417
with FakeTensorMode() as fm:
1418
x = torch.randn(3)[1:]
1419
y = torch.randn(3)[1:]
1421
self._test_cache_key(fm, x, y, z)
1423
def test_cache_key_requires_grad(self):
1424
with FakeTensorMode() as fm:
1425
x = torch.randn(4, 3)
1426
y = torch.randn(4, 3)
1427
z = torch.randn(4, 3, requires_grad=True)
1428
self._test_cache_key(fm, x, y, z)
1430
def test_cache_key_is_conj(self):
1431
with FakeTensorMode() as fm:
1432
x = torch.randn(4, 3, dtype=torch.complex64)
1433
y = torch.randn(4, 3, dtype=torch.complex64)
1434
z = torch.randn(4, 3, dtype=torch.complex64)
1435
torch._C._set_conj(z, not z.is_conj())
1436
self._test_cache_key(fm, x, y, z)
1438
def test_cache_key_is_neg(self):
1439
with FakeTensorMode() as fm:
1440
x = torch.randn(4, 3, dtype=torch.complex64)
1441
y = torch.randn(4, 3, dtype=torch.complex64)
1442
z = torch.randn(4, 3, dtype=torch.complex64)
1443
torch._C._set_neg(z, not z.is_neg())
1444
self._test_cache_key(fm, x, y, z)
1446
def test_cache_key_is_inference(self):
1447
with torch.inference_mode(True):
1448
t = torch.randn(4, 3)
1449
with FakeTensorMode() as fm:
1450
x = torch.randn(4, 3)
1451
y = torch.randn(4, 3)
1452
z = fm.from_tensor(t)
1453
self._test_cache_key(fm, x, y, z)
1455
def test_cache_key_constants(self):
1456
with FakeTensorMode() as fm:
1459
self._test_cache_key(fm, 1.0, 1.0, 1)
1460
self._test_cache_key(fm, 0.0, 0.0, 0)
1462
def assertHitsMisses(self, hits, misses):
1464
Helper to assert on the number of recorded hits and misses.
1466
info = FakeTensorMode.cache_info()
1467
self.assertEqual(info.hits, hits)
1468
self.assertEqual(info.misses, misses)
1470
def assertBypasses(self, reason, count):
1472
Helper to assert on the number of recorded bypasses.
1474
info = FakeTensorMode.cache_info()
1476
self.assertIn(reason, info.bypasses)
1477
self.assertEqual(info.bypasses[reason], count)
1479
self.assertNotIn(reason, info.bypasses)
1481
def test_cache_hit(self):
1483
Test that cache hit/miss counters are updated correctly.
1485
with FakeTensorMode():
1486
x = torch.randn(4, 3)
1487
y = torch.randn(4, 3)
1489
FakeTensorMode.cache_clear()
1490
self.assertHitsMisses(0, 0)
1492
self.assertHitsMisses(0, 1)
1494
self.assertHitsMisses(1, 1)
1497
extract_tensor_metadata(res1),
1498
extract_tensor_metadata(res2),
1501
def test_cache_bypass(self):
1503
Test that cache bypass counters are updated correctly.
1505
with FakeTensorMode():
1506
x = torch.randn(1, 2)
1508
FakeTensorMode.cache_clear()
1509
self.assertBypasses("inplace view", 0)
1512
self.assertBypasses("inplace view", 1)
1514
def test_cache_default_dtype(self):
1516
Test that the default dtype is respected when serving cached results.
1518
with FakeTensorMode():
1519
x = torch.tensor([1, 2], dtype=torch.int32)
1520
torch.set_default_dtype(torch.float32)
1522
FakeTensorMode.cache_clear()
1523
self.assertHitsMisses(0, 0)
1526
self.assertEqual(y.dtype, torch.float32)
1527
self.assertHitsMisses(0, 1)
1529
torch.set_default_dtype(torch.float16)
1531
self.assertEqual(y.dtype, torch.float16)
1532
self.assertHitsMisses(0, 2)
1534
torch.set_default_dtype(torch.float32)
1536
self.assertEqual(y.dtype, torch.float32)
1537
self.assertHitsMisses(1, 2)
1539
@unittest.skipIf(not RUN_CUDA, "requires cuda")
1540
def test_cache_default_device(self):
1542
Test that the default device is respected when serving cached results.
1544
with FakeTensorMode():
1545
FakeTensorMode.cache_clear()
1546
self.assertHitsMisses(0, 0)
1548
torch.set_default_device("cpu")
1549
x = torch.tensor([1, 2])
1551
self.assertEqual(y.device.type, "cpu")
1552
self.assertHitsMisses(0, 1)
1554
torch.set_default_device("cuda")
1555
x = torch.tensor([1, 2])
1557
self.assertEqual(y.device.type, "cuda")
1558
self.assertHitsMisses(0, 2)
1560
torch.set_default_device("cpu")
1561
x = torch.tensor([1, 2])
1563
self.assertEqual(y.device.type, "cpu")
1564
self.assertHitsMisses(1, 2)
1566
def test_cache_inplace_op(self):
1568
Test that inplace ops served from the cache correctly reference the
1571
with FakeTensorMode():
1572
x = torch.randn(1, 2)
1573
y = torch.randn(1, 2)
1575
FakeTensorMode.cache_clear()
1576
self.assertHitsMisses(0, 0)
1579
self.assertHitsMisses(0, 1)
1580
self.assertEqual(id(x), id(z))
1583
self.assertHitsMisses(1, 1)
1584
self.assertEqual(id(x), id(w))
1586
def test_cache_view_op(self):
1588
Test that view ops are handled correctly when served from the cache.
1590
with FakeTensorMode():
1591
x1 = torch.ones(2, requires_grad=True).clone()
1592
x2 = torch.ones(2, requires_grad=True).clone()
1598
self.assertFalse(z1._is_view())
1601
self.assertTrue(z2._is_view())
1606
self.assertTrue(z2._is_view())
1609
self.assertFalse(z1._is_view())
1611
def test_cache_dispatch_key_set(self):
1613
Test that operations that change the dispatch key set bypass caching.
1615
with FakeTensorMode():
1616
FakeTensorMode.cache_clear()
1617
self.assertBypasses("dispatch_key_set mismatch", 0)
1619
x = torch._efficientzerotensor(3)
1620
self.assertTrue(x._is_zerotensor())
1621
self.assertBypasses("dispatch_key_set mismatch", 1)
1623
y = torch._efficientzerotensor(3)
1624
self.assertTrue(y._is_zerotensor())
1625
self.assertBypasses("dispatch_key_set mismatch", 2)
1627
def test_inference_mode(self):
1629
Test that caching handles inference mode correctly.
1631
with FakeTensorMode():
1632
x = torch.randn(4, 3)
1633
y = torch.randn(4, 3)
1635
FakeTensorMode.cache_clear()
1636
self.assertHitsMisses(0, 0)
1640
with torch.inference_mode():
1643
self.assertHitsMisses(0, 2)
1644
self.assertFalse(res1.is_inference())
1645
self.assertTrue(res2.is_inference())
1650
self.assertHitsMisses(1, 2)
1651
self.assertFalse(res3.is_inference())
1653
extract_tensor_metadata(res1),
1654
extract_tensor_metadata(res3),
1657
with torch.inference_mode():
1660
self.assertHitsMisses(2, 2)
1661
self.assertTrue(res4.is_inference())
1663
extract_tensor_metadata(res2),
1664
extract_tensor_metadata(res4),
1668
# Backends for which device-specific variants of FakeTensorOpInfoTest are generated.
only_for = ("cpu", "cuda")

# Run the common_utils parametrization pass over FakeTensorTest.
instantiate_parametrized_tests(FakeTensorTest)
# Generate device-specific variants of FakeTensorOpInfoTest into this module's
# globals(), limited to the backends listed in ``only_for``.
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)
1673
if __name__ == "__main__":