pytorch

Форк
0
/
test_fake_tensor.py 
1914 строк · 68.3 Кб
1
# Owner(s): ["module: meta tensors"]
2

3

4
import contextlib
5
import copy
6
import dataclasses
7
import inspect
8
import itertools
9
import pickle
10
import unittest
11
import weakref
12
from unittest.mock import patch
13

14
import numpy as np
15
import torch
16
import torch._dynamo
17
import torch._functorch.config
18
import torch._prims as prims
19
import torch.testing._internal.optests as optests
20
import torch.utils._pytree as pytree
21

22
from torch import distributed as dist
23
from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
24
from torch._dynamo.testing import make_test_cls_with_patches, rand_strided
25
from torch._guards import tracing, TracingContext
26
from torch._subclasses.fake_tensor import (
27
    DynamicOutputShapeException,
28
    extract_tensor_metadata,
29
    FakeTensor,
30
    FakeTensorConverter,
31
    FakeTensorMode,
32
    unset_fake_temporarily,
33
    UnsupportedOperatorException,
34
    _CacheKeyState
35
)
36
from torch.fx.experimental.proxy_tensor import make_fx
37
from torch.fx.experimental.symbolic_shapes import (
38
    DimDynamic,
39
    free_symbols,
40
    ShapeEnv,
41
    ShapeEnvSettings,
42
    StatelessSymbolicContext,
43
    statically_known_true,
44
)
45
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
46
from torch.testing import FileCheck
47
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
48
from torch.testing._internal.common_device_type import (
49
    instantiate_device_type_tests,
50
    OpDTypes,
51
    ops,
52
)
53
from torch.testing._internal.common_utils import (
54
    instantiate_parametrized_tests,
55
    parametrize,
56
    run_tests,
57
    skipIfCrossRef,
58
    skipIfRocm,
59
    skipIfTorchDynamo,
60
    TemporaryFileName,
61
    TEST_WITH_TORCHDYNAMO,
62
    TestCase,
63
)
64

65
from torch.testing._internal.inductor_utils import GPU_TYPE
66
from torch.testing._internal.custom_op_db import custom_op_db
67
from torch.testing._internal.jit_utils import RUN_CUDA
68
from torch.utils._mode_utils import no_dispatch
69
from torch.utils._python_dispatch import TorchDispatchMode
70

71
aten = torch.ops.aten
72

73
torch._dynamo.config.fake_tensor_cache_enabled = True
74
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
75

76

77
def expectedFailurePropagateRealTensors(fn):
78
    fn._expected_failure_propagate_real_tensors = True
79
    return fn
80

81

82
class FakeTensorTest(TestCase):
83
    def checkType(self, t, device_str, size):
84
        self.assertTrue(isinstance(t, FakeTensor))
85
        self.assertEqual(t.device.type, device_str)
86
        self.assertEqual(list(t.size()), size)
87

88
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
89
    def test_cuda_initialized(self):
90
        # doesnt error
91
        with FakeTensorMode():
92
            p = torch.randn(4, 2, requires_grad=True, device="cuda")
93
            x = torch.randn(8, 4, device="cuda")
94
            y = torch.mm(x, p).square().sum()
95
            y.backward()
96

97
    def test_basic(self):
98
        x = torch.empty(2, 2, device="cpu")
99
        y = torch.empty(4, 2, 2, device="cpu")
100
        with FakeTensorMode() as mode:
101
            x = mode.from_tensor(x)
102
            y = mode.from_tensor(y)
103
            z = x + y
104
            self.assertEqual(z.shape, (4, 2, 2))
105
            self.assertEqual(z.device, torch.device("cpu"))
106
            self.assertTrue(isinstance(z, FakeTensor))
107

108
    def test_custom_op_fallback(self):
109
        from torch.library import impl, Library
110

111
        try:
112
            test_lib = Library("my_test_op", "DEF")  # noqa: TOR901
113
            test_lib.define("foo(Tensor self) -> Tensor")
114

115
            @impl(test_lib, "foo", "CPU")
116
            def foo_impl(self):
117
                return self.cos()
118

119
            x = torch.empty(2, 2, device="cpu")
120
            with self.assertRaisesRegex(
121
                UnsupportedOperatorException, "my_test_op.foo.default"
122
            ):
123
                with FakeTensorMode(allow_fallback_kernels=True) as mode:
124
                    x = mode.from_tensor(x)
125
                    torch.ops.my_test_op.foo(x)
126

127
        finally:
128
            test_lib._destroy()
129

130
    def test_parameter_instantiation(self):
131
        with FakeTensorMode():
132
            x = torch.rand([4])
133
            y = torch.nn.parameter.Parameter(x)
134
            self.assertTrue(isinstance(y, torch.nn.Parameter))
135

136
    @unittest.skipIf(not dist.is_available(), "requires distributed")
137
    def test_fsdp_flat_param(self):
138
        from torch.distributed.fsdp._flat_param import FlatParameter
139

140
        with FakeTensorMode() as m:
141
            data = torch.randn(2, 2)
142
            param = FlatParameter(data, requires_grad=True)
143
        self.assertIsInstance(param, FlatParameter)
144
        self.assertIsInstance(param, torch.nn.Parameter)
145
        self.assertIsInstance(param, FakeTensor)
146

147
    def test_non_parameter_grad(self):
148
        mode = FakeTensorMode()
149
        t = torch.rand([4], requires_grad=True)
150
        fake_t = mode.from_tensor(t)
151
        self.assertEqual(fake_t.requires_grad, t.requires_grad)
152

153
    @unittest.skipIf(
154
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
155
    )
156
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
157
    def test_index_cuda_with_cpu(self):
158
        with FakeTensorMode():
159
            x = torch.rand([2048], device="cuda")
160
            out = x[torch.zeros([36], dtype=torch.int64)]
161
            self.checkType(out, "cuda", [36])
162

163
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
164
    def test_shape_take_not_device(self):
165
        with FakeTensorMode():
166
            x = torch.empty(1, device="cpu")
167
            y = torch.empty(8, 8, device="cuda")
168
            out = x.resize_as_(y)
169
            self.assertEqual(out.shape, (8, 8))
170
            self.assertEqual(out.device.type, "cpu")
171
            self.assertTrue(isinstance(out, FakeTensor))
172

173
    def test_repr(self):
174
        with FakeTensorMode():
175
            x = torch.empty(2, 2, device="cpu")
176
            self.assertEqual(repr(x), "FakeTensor(..., size=(2, 2))")
177
            x = torch.empty(2, 2, device="meta")
178
            self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")
179

180
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
181
    def test_zero_dim(self):
182
        with FakeTensorMode() as mode:
183
            x = torch.tensor(0.0)
184
            y = torch.rand([4, 4], device="cuda")
185
            out = x + y
186
            self.assertEqual(out.shape, (4, 4))
187
            self.assertEqual(out.device, y.device)
188
            self.assertTrue(isinstance(out, FakeTensor))
189

190
    def test_nan_to_num(self):
191
        with FakeTensorMode():
192
            for dtype in [torch.float16, torch.float32]:
193
                x = torch.rand([4], dtype=dtype)
194
                y = torch.nan_to_num(x, nan=None)
195
                z = torch.nan_to_num(x, 0.0)
196
                self.assertEqual(dtype, y.dtype)
197
                self.assertEqual(dtype, z.dtype)
198

199
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
200
    def test_throw(self):
201
        x = torch.tensor(0.0)  # TODO: tensor() errors
202
        with FakeTensorMode() as mode:
203
            x_conv = mode.from_tensor(x)
204
            y = torch.rand([4, 4], device="cuda")
205
            z = torch.rand([4, 4], device="cpu")
206
            self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))
207

208
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
209
    def test_type_as(self):
210
        with FakeTensorMode():
211
            x = torch.rand([16, 1], device="cpu")
212
            y = torch.rand([4, 4], device="cuda")
213
            out = x.type_as(y)
214
            self.assertEqual(out.device.type, "cuda")
215
            self.assertTrue(isinstance(out, FakeTensor))
216

217
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
218
    def test_setitem(self):
219
        for device in ["cpu", "cuda"]:
220
            with FakeTensorMode():
221
                x = torch.rand([16, 1], device=device)
222
                x[..., 0] = 0
223

224
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
225
    def test_device_inplace_copy(self):
226
        with FakeTensorMode():
227
            x = torch.rand([8, 8], device="cpu")
228
            y = torch.rand([8, 8], device="cuda")
229
            assert x.copy_(y).device.type == "cpu"
230
            assert y.copy_(x).device.type == "cuda"
231

232
    def test_fake_dispatch_keys(self):
233
        with FakeTensorMode():
234
            x = torch.rand([4])
235
            f = (
236
                FileCheck()
237
                .check("CPU")
238
                .check("ADInplaceOrView")
239
                .check("AutogradCPU")
240
                .check("AutocastCPU")
241
            )
242
            f.run(torch._C._dispatch_key_set(x))
243

244
            with torch.inference_mode():
245
                x = torch.rand([4])
246
                y = x + x
247
                FileCheck().check("CPU").check("AutocastCPU").run(
248
                    torch._C._dispatch_key_set(y)
249
                )
250
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(
251
                    torch._C._dispatch_key_set(y)
252
                )
253

254
    def test_batch_tensor(self):
255
        x = torch.rand((3, 4, 5))
256
        b = _add_batch_dim(x, 0, 0)
257
        mode = FakeTensorMode()
258
        fake_b = mode.from_tensor(b)
259
        prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)
260

261
        b1 = _add_batch_dim(x, 1, 1)
262
        b2 = _add_batch_dim(b1, 0, 2)
263
        fake_b2 = mode.from_tensor(b2)
264
        prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
265
        self.assertTrue(is_batchedtensor(fake_b2))
266
        fake_b1 = get_unwrapped(fake_b2)
267
        self.assertTrue(is_batchedtensor(fake_b1))
268
        fake_tensor = get_unwrapped(fake_b1)
269
        self.assertIsInstance(fake_tensor, FakeTensor)
270

271
    def test_constructor(self):
272
        with FakeTensorMode():
273
            x = torch.rand([4, 4], device="cpu")
274

275
        self.assertTrue(isinstance(x, FakeTensor))
276
        self.assertTrue(x.device.type == "cpu")
277

278
    def test_mode(self):
279
        with FakeTensorMode():
280
            y = torch.rand([4], device="cpu")
281
            out = y + y
282

283
        self.assertTrue(isinstance(out, FakeTensor))
284

285
    def test_full(self):
286
        # Test torch.full returns tensor with correct dtype
287
        with torch._subclasses.CrossRefFakeMode():
288
            y = torch.full((4, 4), 1)
289

290
    def check_function_with_fake(self, fn):
291
        out = fn()
292
        with torch._subclasses.FakeTensorMode():
293
            out_fake = fn()
294

295
        for a, b in zip(pytree.tree_leaves(out), pytree.tree_leaves(out_fake)):
296
            if not isinstance(a, torch.Tensor):
297
                self.assertTrue(not isinstance(b, torch.Tensor))
298
                continue
299

300
            prims.utils.compare_tensor_meta(a, b, check_strides=True)
301

302
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
303
    def test_non_kwarg_device(self):
304
        with FakeTensorMode():
305
            x = torch.rand([16, 1], device="cpu")
306
            y = x.to(torch.device("cpu"))
307
            self.assertIs(x, y)
308
            z = x.to(torch.device("cuda"))
309
            self.assertEqual(z.device.type, "cuda")
310

311
    def test_non_overlapping_stride_zero(self):
312
        def foo():
313
            x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
314
            return x.half()
315

316
        self.check_function_with_fake(foo)
317

318
    def test_fake_mode_error(self):
319
        x = torch.rand([4, 4])
320

321
        with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
322
            with FakeTensorMode():
323
                y = x[0]
324

325
    @unittest.skipIf(
326
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
327
    )
328
    def test_fake_grad_copy(self):
329
        x = torch.rand([4, 4], requires_grad=True)
330
        x.grad = torch.rand([4, 4])
331
        mode = FakeTensorMode()
332
        fake_x = mode.from_tensor(x)
333
        prims.utils.compare_tensor_meta(fake_x, x)
334
        prims.utils.compare_tensor_meta(fake_x.grad, x.grad)
335

336
        self.assertTrue(isinstance(fake_x.grad, FakeTensor))
337

338
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
339
    def test_index_put_error(self):
340
        mode = FakeTensorMode()
341
        for context in [contextlib.nullcontext, lambda: mode]:
342
            with context():
343
                y = torch.randn(2, 2, 3)
344
                x = torch.randn(2, 2, 3).to("cuda")
345
                with self.assertRaises(RuntimeError):
346
                    x[[1, 1]] = y
347

348
                with self.assertRaises(RuntimeError):
349
                    torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y)
350

351
                # no error
352
                torch.ops.aten.index_put(
353
                    x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0)
354
                )
355
                torch.ops.aten.index_put_(
356
                    x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0)
357
                )
358

359
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
360
    def test_like_constructor(self):
361
        with FakeTensorMode():
362
            x = torch.rand([4, 4])
363
            y = torch.ones_like(x)
364
            self.assertTrue(isinstance(y, FakeTensor))
365
            self.assertEqual(y.device.type, "cpu")
366
            z = torch.ones_like(x, device="cuda")
367
            self.assertTrue(isinstance(z, FakeTensor))
368
            self.assertEqual(z.device.type, "cuda")
369

370
    def test_binary_op_type_promotion(self):
371
        with FakeTensorMode():
372
            x = torch.empty([2, 2], dtype=torch.float)
373
            y = torch.empty([2, 2], dtype=torch.int64)
374
            out = x / y
375
            self.assertEqual(out.dtype, torch.float)
376
            self.assertEqual(out.device.type, "cpu")
377

378
    @unittest.skipIf(
379
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
380
    )
381
    def test_from_numpy(self):
382
        with FakeTensorMode():
383
            x = torch.tensor(np.zeros([4, 4]))
384
            self.checkType(x, "cpu", [4, 4])
385

386
    def test_randperm(self):
387
        x = torch.randperm(10)
388
        y = torch.randperm(5, device="cpu")
389
        with FakeTensorMode():
390
            x1 = torch.randperm(10)
391
            prims.utils.compare_tensor_meta(x, x1)
392
            y1 = torch.randperm(5, device="cpu")
393
            prims.utils.compare_tensor_meta(y, y1)
394

395
    def test_print_in_fake_mode(self):
396
        x = torch.zeros(2)
397
        # does not fail
398
        with FakeTensorMode():
399
            out = str(x)
400
        assert "FakeTensor" not in out
401

402
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
403
    def test_upsample_bilinear_small_channels(self):
404
        out = []
405
        mode = FakeTensorMode()
406
        for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
407
            with context():
408
                arg0_1 = torch.empty_strided(
409
                    (3, 427, 640), (1, 1920, 3), dtype=torch.float32, device="cuda"
410
                )
411
                unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
412
                out.append(
413
                    torch.ops.aten.upsample_bilinear2d.default(
414
                        unsqueeze, [800, 1199], False
415
                    )
416
                )
417

418
        self.assertTrue(out[1].is_contiguous())
419
        self.checkMetaProps(out[0], out[1])
420

421
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
422
    def test_cpu_fallback(self):
423
        with FakeTensorMode(allow_fallback_kernels=False):
424
            filters = torch.randn(8, 4, 3, 3).cuda()
425
            inputs = torch.randn(1, 4, 5, 5).cuda()
426
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
427
            self.assertEqual(out.device.type, "cuda")
428
            self.assertEqual(list(out.size()), [1, 8, 5, 5])
429

430
        with FakeTensorMode(allow_fallback_kernels=True):
431
            # intentionally bad inputs
432
            filters = torch.randn(8, 20, 3, 3).cuda()
433
            inputs = torch.randn(1, 7, 10, 5).cuda()
434
            with self.assertRaises(RuntimeError):
435
                torch.nn.functional.conv2d(inputs, filters, padding=1)
436

437
        with FakeTensorMode(allow_fallback_kernels=True):
438
            filters = torch.randn(8, 4, 3, 3).cuda()
439
            inputs = torch.randn(1, 4, 5, 5).cuda()
440

441
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
442
            self.assertEqual(out.device.type, "cuda")
443
            self.assertEqual(list(out.size()), [1, 8, 5, 5])
444

445
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
446
    def test_out_multi_device(self):
447
        with FakeTensorMode():
448
            x = torch.rand([4])
449
            y = torch.rand([4], device="cuda")
450

451
            with self.assertRaisesRegex(Exception, "found.+two.+devices"):
452
                torch.sin(x, out=y)
453

454
            with self.assertRaisesRegex(Exception, "found.+two.+devices"):
455
                x.add_(y)
456

457
    @unittest.skipIf(
458
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
459
    )
460
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
461
    def test_normalize_device(self):
462
        with FakeTensorMode():
463
            x = torch.empty(1, device="cuda")
464
            y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
465
            out = x + y
466
        self.checkType(out, "cuda", [1])
467

468
    def test_recursive_invocation(self):
469
        mode = FakeTensorMode()
470
        with mode:
471
            x = torch.tensor(2)
472
            mode.in_kernel_invocation = True
473
            y = x + x
474
            self.assertTrue(mode.in_kernel_invocation)
475

476
    @unittest.skipIf(
477
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
478
    )
479
    @skipIfRocm
480
    @parametrize(
481
        "allow_fallback_kernels",
482
        [False, True],
483
        lambda a: "with_fallback" if a else "without_fallback",
484
    )
485
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
486
    def test_cudnn_rnn(self, allow_fallback_kernels):
487
        def fn(
488
            a0,
489
            b0,
490
            b1,
491
            b2,
492
            b3,
493
            b4,
494
            b5,
495
            b6,
496
            b7,
497
            b8,
498
            b9,
499
            b10,
500
            b11,
501
            b12,
502
            b13,
503
            b14,
504
            b15,
505
            a3,
506
            a4,
507
            a5,
508
        ):
509
            a1 = [
510
                b0,
511
                b1,
512
                b2,
513
                b3,
514
                b4,
515
                b5,
516
                b6,
517
                b7,
518
                b8,
519
                b9,
520
                b10,
521
                b11,
522
                b12,
523
                b13,
524
                b14,
525
                b15,
526
            ]
527
            return torch.ops.aten._cudnn_rnn(
528
                a0,
529
                a1,
530
                4,
531
                a3,
532
                a4,
533
                a5,
534
                2,
535
                2048,
536
                0,
537
                2,
538
                False,
539
                0.0,
540
                False,
541
                True,
542
                [],
543
                None,
544
            )
545

546
        mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
547
        for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
548
            with context():
549
                inps1 = [
550
                    torch.randn([92, 8, 2048]).cuda(),
551
                    torch.randn([8192, 2048]).cuda(),
552
                    torch.randn([8192, 2048]).cuda(),
553
                    torch.randn([8192]).cuda(),
554
                    torch.randn([8192]).cuda(),
555
                    torch.randn([8192, 2048]).cuda(),
556
                    torch.randn([8192, 2048]).cuda(),
557
                    torch.randn([8192]).cuda(),
558
                    torch.randn([8192]).cuda(),
559
                    torch.randn([8192, 4096]).cuda(),
560
                    torch.randn([8192, 2048]).cuda(),
561
                    torch.randn([8192]).cuda(),
562
                    torch.randn([8192]).cuda(),
563
                    torch.randn([8192, 4096]).cuda(),
564
                    torch.randn([8192, 2048]).cuda(),
565
                    torch.randn([8192]).cuda(),
566
                    torch.randn([8192]).cuda(),
567
                    torch.randn([167837696]).cuda(),
568
                    torch.randn([4, 8, 2048]).cuda(),
569
                    torch.randn([4, 8, 2048]).cuda(),
570
                ]
571
                inps2 = inps1
572
                inps2[len(inps2) - 1] = None  # argument `cx` can be None
573

574
                for inps in [inps1, inps2]:
575
                    out = fn(*inps)
576
                    self.assertIs(out[4], inps[-3])
577
                    for ten in out:
578
                        if i == 1:
579
                            self.assertTrue(isinstance(ten, FakeTensor))
580
                        self.assertEqual(ten.device.type, "cuda")
581

582
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
583
    def test_cuda_lstm(self):
584
        # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
585
        with torch.backends.cudnn.flags(enabled=False):
586
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
587
            with fake_tensor_mode:
588
                N = 5
589
                L = 4
590
                H_in = 2
591
                hidden_size = 3
592
                proj_size = 2
593
                num_layers = 2
594
                bidir = False
595
                D = 2 if bidir else 1
596
                H_out = proj_size if proj_size > 0 else hidden_size
597

598
                lstm = torch.nn.LSTM(
599
                    input_size=H_in,
600
                    hidden_size=hidden_size,
601
                    num_layers=num_layers,
602
                    proj_size=proj_size,
603
                    batch_first=False,
604
                    bias=True,
605
                    bidirectional=bidir,
606
                    device="cuda",
607
                )
608

609
                h_0 = torch.randn((num_layers * D, N, H_out), device="cuda")
610
                c_0 = torch.randn((num_layers * D, N, hidden_size), device="cuda")
611
                inp = torch.randn((L, N, H_in), device="cuda")
612
                (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
613
                output.sum().backward()
614

615
                self.assertEqual(output.shape, (L, N, D * H_out))
616
                self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
617
                self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))
618

619
    def test_data_dependent_operator(self):
620
        with FakeTensorMode(allow_fallback_kernels=False):
621
            x = torch.rand([10, 10])
622

623
            self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))
624

625
    def test_parameter_view(self):
626
        x = torch.nn.Parameter(torch.randn(4))
627
        x_view = x.view(4)
628
        mode = FakeTensorMode()
629
        fake_x_view = mode.from_tensor(x_view)
630
        fake_x = mode.from_tensor(x)
631
        self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter))
632
        self.assertTrue(isinstance(fake_x, torch.nn.Parameter))
633

634
    def test_tolist(self):
635
        shape_env = ShapeEnv()
636
        with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env):
637
            x = torch.rand([10])
638
            x.tolist()
639

640
    # Propagate real tensors doesn't work with fake-on-fake
641
    @expectedFailurePropagateRealTensors
642
    def test_same_shape_env_preserved(self):
643
        shape_env = ShapeEnv()
644
        mode1 = FakeTensorMode(shape_env=shape_env)
645
        t1 = mode1.from_tensor(
646
            torch.randn(10),
647
            symbolic_context=StatelessSymbolicContext(
648
                dynamic_sizes=[DimDynamic.DYNAMIC], constraint_sizes=[None]
649
            ),
650
        )
651
        mode2 = FakeTensorMode(shape_env=shape_env)
652
        t2 = mode2.from_tensor(t1)
653
        # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here
654
        self.assertIsNot(t2, t1)
655
        self.assertIs(t1.fake_mode, mode1)
656
        self.assertIs(t2.fake_mode, mode2)
657
        self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
658
        self.assertEqual(str(t2.size(0)), str(t1.size(0)))
659

660
    # TODO: Support NJT.  There's also some funny business with dynamic shapes
661
    # which would need to be dealt with as well
662
    @expectedFailurePropagateRealTensors
663
    def test_jagged_fake_to_fake_preserved(self):
664
        from torch.nested._internal.nested_tensor import jagged_from_list
665

666
        S0, S1, S2 = 3, 4, 5
667
        D = 4
668
        a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
669
        b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
670
        c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
671
        offsets = None
672
        jt, _ = jagged_from_list([a, b, c], offsets)
673
        shape_env = ShapeEnv()
674
        mode1 = FakeTensorMode(shape_env=shape_env)
675
        t1 = mode1.from_tensor(jt)
676
        mode2 = FakeTensorMode(shape_env=shape_env)
677
        t2 = mode2.from_tensor(t1)
678
        # It's not obvious that the invocation above makes it dynamic but it
679
        # does!
680
        self.assertTrue(free_symbols(t1.size()))
681
        self.assertIsNot(t2, t1)
682
        self.assertIs(t1.offsets().fake_mode, mode1)
683
        self.assertIs(t2.offsets().fake_mode, mode2)
684
        self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
685
        self.assertEqual(str(t2.size(1)), str(t1.size(1)))
686

687
    def checkMetaProps(self, t1, t2):
688
        prims.utils.compare_tensor_meta(t1, t2, check_strides=True)
689

690
    @skipIfCrossRef
691
    def test_deepcopy(self):
692
        with FakeTensorMode() as mode:
693
            pass
694
        mod = torch.nn.BatchNorm2d(10)
695
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
696
            mod_copied = copy.deepcopy(mod)
697

698
        def check_copy(mod, mod_copied):
699
            for name, param in itertools.chain(
700
                mod.named_parameters(), mod.named_buffers()
701
            ):
702
                param_copied = getattr(mod_copied, name)
703
                self.checkMetaProps(param, param_copied)
704
                self.assertTrue(isinstance(param_copied, FakeTensor))
705
                self.assertEqual(
706
                    isinstance(param, torch.nn.Parameter),
707
                    isinstance(param_copied, torch.nn.Parameter),
708
                )
709
                self.assertEqual(param.requires_grad, param_copied.requires_grad)
710

711
        check_copy(mod, mod_copied)
712

713
        class ModuleNew(torch.nn.Module):
714
            def __init__(self) -> None:
715
                super().__init__()
716
                self.a = torch.rand([10, 2])
717
                self.b = self.a
718
                self.c = self.a[0]
719

720
        mod = ModuleNew()
721
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
722
            mod_copied = copy.deepcopy(mod)
723

724
        self.assertIs(mod_copied.a, mod_copied.b)
725
        self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)
726

727
    @unittest.skipIf(
728
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
729
    )
730
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
731
    def test_new(self):
732
        with FakeTensorMode():
733
            a = torch.rand([16, 1])
734
            self.checkType(a.new(10, 10), "cpu", [10, 10])
735
            self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
736
            b = torch.rand([4, 4], device="cuda")
737
            self.checkType(b.new(device="cuda"), "cuda", [0])
738
            self.checkType(a.new(torch.rand([1])), "cpu", [1])
739

740
    @unittest.skipIf(
741
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
742
    )
743
    def test_scalar_inputs(self):
744
        with FakeTensorMode():
745
            self.checkType(torch.div(3, 2), "cpu", [])
746
            ten = torch.zeros(2, dtype=torch.int32) * 2.0
747
            self.assertEqual(ten.dtype, torch.float)
748
            self.checkType(ten, "cpu", [2])
749

750
    @unittest.skipIf(
751
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
752
    )
753
    def test_allow_meta(self):
754
        def run_meta():
755
            with FakeTensorMode():
756
                x = torch.rand([4], device="meta")
757
                return x + x
758

759
        self.checkType(run_meta(), "meta", [4])
760

761
        with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
762
            self.assertRaises(Exception, run_meta)
763

764
    def test_embedding_bag_meta(self):
765
        def f():
766
            # This behavior was originally unintentional but we see people
767
            # relying on it
768
            embedding = torch.nn.EmbeddingBag(10, 3, mode="sum", device="meta")
769
            input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
770
            offsets = torch.tensor([0, 4], dtype=torch.long)
771
            return embedding(input, offsets)
772

773
        real_out = f()
774
        with FakeTensorMode():
775
            fake_out = f()
776

777
        for r, f in zip(real_out, fake_out):
778
            self.assertEqual(r.size(), f.size())
779
            self.assertEqual(r.device, f.device)
780

781
    @unittest.skipIf(
782
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
783
    )
784
    def test_mixed_real_and_fake_inputs(self):
785
        class _TestPattern(torch.nn.Module):
786
            def __init__(self) -> None:
787
                super().__init__()
788
                self.conv = torch.nn.Conv2d(1, 1, 1)
789
                self.bn = torch.nn.BatchNorm2d(1)
790

791
            def forward(self, input):
792
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
793
                scale_factor = self.bn.weight / running_std
794
                weight_shape = [1] * len(self.conv.weight.shape)
795
                weight_shape[0] = -1
796
                bias_shape = [1] * len(self.conv.weight.shape)
797
                bias_shape[1] = -1
798
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
799
                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
800
                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
801
                conv_orig = conv / scale_factor.reshape(bias_shape)
802
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
803
                conv = self.bn(conv_orig)
804
                return conv
805

806
        example_inputs = (torch.randn(1, 1, 3, 3),)
807
        mod = _TestPattern()
808
        with FakeTensorMode(allow_non_fake_inputs=True):
809
            out = mod(torch.randn(1, 1, 3, 3))
810
        self.checkType(out, "cpu", (1, 1, 3, 3))
811

812
    @unittest.skipIf(
813
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
814
    )
815
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
816
    def test_aten_copy_multi_device(self):
817
        with FakeTensorMode():
818
            x1 = torch.rand(4, device="cpu")
819
            x2 = torch.rand(4, device="cuda")
820
            copy1 = torch.ops.aten.copy.default(x1, x2)
821
            copy2 = torch.ops.aten.copy.default(x2, x1)
822
            out = torch.empty(4, device="cpu")
823
            torch.ops.aten.copy.out(x1, x2, out=out)
824
        self.checkType(copy1, "cpu", (4,))
825
        self.checkType(copy2, "cuda", (4,))
826
        self.checkType(out, "cpu", (4,))
827

828
    @unittest.skipIf(
829
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
830
    )
831
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
832
    def test_aten_index_multi_device(self):
833
        with FakeTensorMode():
834
            x1 = torch.rand(4, 4, device="cpu")
835
            x2 = torch.rand(4, 4, device="cuda")
836
            i1 = torch.tensor([0, 1], device="cuda")
837
            i2 = torch.tensor([0, 1], device="cpu")
838
            # NB: This one does not work: cuda indices not allowed on cpu
839
            # tensor
840
            # r1 = torch.ops.aten.index(x1, i1)
841
            r2 = torch.ops.aten.index(x2, i2)
842

843
            y1 = torch.rand(4, device="cpu")
844
            y2 = torch.rand(4, device="cuda")
845
            j1 = torch.tensor([2], device="cuda")
846
            j2 = torch.tensor([2], device="cpu")
847
            r3 = torch.ops.aten.index_put.default(x1, j1, y1)
848
            r4 = torch.ops.aten.index_put.default(x2, j2, y2)
849
        # self.checkType(r1, "cpu", ())
850
        self.checkType(r2, "cuda", ())
851
        self.checkType(r3, "cpu", (4, 4))
852
        self.checkType(r4, "cuda", (4, 4))
853

854
    @unittest.skipIf(
855
        TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
856
    )
857
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
858
    def test_aten_slice_scatter_multi_device(self):
859
        with FakeTensorMode():
860
            x1 = torch.rand(4, 4, device="cpu")
861
            y1 = torch.rand(2, 4, device="cuda")
862
            x2 = torch.rand(4, 4, device="cuda")
863
            y2 = torch.rand(2, 4, device="cpu")
864
            out = torch.empty(4, 4, device="cpu")
865
            r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
866
            r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
867
            r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
868
        self.checkType(r1, "cpu", (4, 4))
869
        self.checkType(r2, "cuda", (4, 4))
870
        self.checkType(r3, "cpu", (4, 4))
871
        self.checkType(out, "cpu", (4, 4))
872

873
    def test__adaptive_avg_pool2d_backward(self):
874
        with FakeTensorMode():
875
            grad_out = torch.rand(2, 3, 4, 4)
876
            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
877
            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
878
            self.assertTrue(
879
                torch._prims_common.suggest_memory_format(grad_in)
880
                == torch.channels_last
881
            )
882

883
    def test_export_numpy(self):
884
        class MyNumpyModel(torch.nn.Module):
885
            def forward(self, input):
886
                input = input.numpy()
887
                return input + np.random.randn(*input.shape)
888

889
        with FakeTensorMode():
890
            ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
891
            self.assertTrue(isinstance(ep, torch.export.ExportedProgram))
892

893
    def test_unsqueeze_copy(self):
894
        shape_env = ShapeEnv()
895
        t1 = torch.ones(2, 2, 768)
896
        with FakeTensorMode(shape_env=shape_env) as fake_mode:
897
            t = fake_mode.from_tensor(
898
                t1,
899
                symbolic_context=StatelessSymbolicContext(
900
                    dynamic_sizes=[
901
                        DimDynamic.DYNAMIC,
902
                        DimDynamic.STATIC,
903
                        DimDynamic.STATIC,
904
                    ],
905
                ),
906
            )
907

908
        self.assertEqual(t.shape[0], torch.ops.aten.unsqueeze_copy(t, 1).shape[0])
909

910
    def test_alias_call(self):
911
        fwAD = torch.autograd.forward_ad
912

913
        def f(x):
914
            return 4312491 * x
915

916
        with torch._subclasses.fake_tensor.FakeTensorMode():
917
            with fwAD.dual_level():
918
                x = torch.randn(3, device="cpu")
919
                y = torch.ones_like(x)
920
                dual = fwAD.make_dual(x, y)
921
                r = f(dual)
922

923
        self.assertIsInstance(r, FakeTensor)
924
        self.assertEqual(r.size(), [3])
925

926

927
instantiate_parametrized_tests(FakeTensorTest)
928

929

930
def make_propagate_real_tensors_cls(cls):
931
    cls = make_test_cls_with_patches(
932
        cls,
933
        "PropagateRealTensors",
934
        "_propagate_real_tensors",
935
        (torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
936
        xfail_prop="_expected_failure_propagate_real_tensors",
937
        decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"),
938
    )
939
    cls.__file__ = __file__
940
    cls.__module__ = __name__
941
    globals()[cls.__name__] = cls
942

943

944
make_propagate_real_tensors_cls(FakeTensorTest)
945

946

947
class FakeTensorConstHandling(TestCase):
948
    def assertConst(self, *args):
949
        for arg in args:
950
            self.assertTrue(arg.constant is not None)
951

952
    def assertNotConst(self, *args):
953
        for arg in args:
954
            self.assertTrue(arg.constant is None)
955

956
    def test_simple(self):
957
        with FakeTensorMode():
958
            x = torch.tensor(4.0)
959
            self.assertEqual(x.item(), 4.0)
960

961
    def test_inplace_add(self):
962
        with FakeTensorMode():
963
            x = torch.tensor(4.0)
964
            y = x.add_(1)
965
            self.assertEqual(x.item(), 5.0)
966
            self.assertEqual(y.item(), 5.0)
967
            self.assertConst(x, y)
968

969
    def test_shared_storages(self):
970
        with FakeTensorMode():
971
            x = torch.tensor([4.0])
972
            y = x[:]
973

974
            self.assertEqual(x.storage()._cdata, y.storage()._cdata)
975
            self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)
976

977
    def test_constant_invalidation(self):
978
        with FakeTensorMode():
979
            x = torch.tensor([1.0])
980
            self.assertConst(x)
981
            y = torch.rand([1])
982
            x.add_(y)
983
            self.assertNotConst(x)
984

985
    def test_inplace_view_invalidation(self):
986
        with FakeTensorMode():
987
            x = torch.tensor([1])
988
            self.assertConst(x)
989
            x.resize_([2])
990
            self.assertEqual(x.size(0), 2)
991
            self.assertNotConst(x)
992

993
    def test_fake_tensor_in_intlist_repro(self):
994
        def fn(tensors):
995
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
996
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
997
            return tensors[0].new_full(batch_shape, 0.0)
998

999
        with self.assertRaises(
1000
            torch._subclasses.fake_tensor.DataDependentOutputException
1001
        ):
1002
            with torch._subclasses.fake_tensor.FakeTensorMode():
1003
                a = torch.randn(3, 800, 1199)
1004
                b = torch.randn(3, 800, 800)
1005
                inputs = [a, b]
1006
                ref = fn(inputs)
1007

1008
    def test_fake_tensor_batch_norm_cpu(self):
1009
        with torch._subclasses.CrossRefFakeMode():
1010
            m = torch.nn.Sequential(
1011
                torch.nn.BatchNorm2d(10),
1012
                torch.nn.ReLU(),
1013
            )
1014
            m.eval()
1015
            out = m(torch.randn([2, 10, 8, 8]))
1016

1017
    def test_shared_storage_invalidation(self):
1018
        with FakeTensorMode():
1019
            x = torch.tensor([1.0])
1020
            y = x[:]
1021
            self.assertConst(x, y)
1022
            y.add_(torch.rand([1]))
1023
            self.assertNotConst(x, y)
1024

1025
    def test_aliased_const_write(self):
1026
        with FakeTensorMode():
1027
            x = torch.tensor([1])
1028
            y = x.expand([4])
1029
            self.assertNotConst(y)
1030
            y[0] = 1
1031
            self.assertNotConst(x)
1032

1033
    def test_constant_propagate_through_functions(self):
1034
        with FakeTensorMode():
1035
            y = torch.div(4, 4, rounding_mode="trunc")
1036
            self.assertConst(y)
1037

1038

1039
make_propagate_real_tensors_cls(FakeTensorConstHandling)
1040

1041

1042
def contains_type(type: torch.Type, maybe_contained_type: torch.Type):
1043
    return maybe_contained_type.isSubtypeOf(type) or any(
1044
        contains_type(e, maybe_contained_type) for e in type.containedTypes()
1045
    )
1046

1047

1048
class FakeTensorOpInfoTest(TestCase):
1049
    @ops(custom_op_db, dtypes=OpDTypes.any_one)
1050
    def test_fake(self, device, dtype, op):
1051
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
1052
        for sample_input in sample_inputs_itr:
1053
            args = (sample_input.input,) + sample_input.args
1054
            kwargs = sample_input.kwargs
1055
            optests.fake_check(op, args, kwargs)
1056

1057

1058
make_propagate_real_tensors_cls(FakeTensorOpInfoTest)
1059
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))
1060
instantiate_device_type_tests(
1061
    PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",)  # noqa: F821
1062
)
1063

1064

1065
class FakeTensorConverterTest(TestCase):
1066
    def test_memoized_conversion_to_meta(self):
1067
        x = torch.rand(2, 2, 2)
1068
        mode = FakeTensorMode()
1069
        self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))
1070

1071
    def test_memoized_conversion_from_meta(self):
1072
        x = torch.rand(2, 2).to(device="meta")
1073
        mode = FakeTensorMode()
1074
        converter = mode.fake_tensor_converter
1075
        self.assertTrue(
1076
            converter.from_meta_and_device(mode, x, "cpu")
1077
            is converter.from_meta_and_device(mode, x, "cpu")
1078
        )
1079

1080
    def test_separate_tensor_storages_view(self):
1081
        x = torch.rand(2, 2, 2)
1082
        y = x[0]
1083
        mode = FakeTensorMode()
1084
        converter = mode.fake_tensor_converter
1085
        x_conv = converter.from_real_tensor(mode, x)
1086
        y_conv = converter.from_real_tensor(mode, y)
1087
        self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))
1088

1089
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1090
    def test_separate_tensor_storages_non_view(self):
1091
        x = torch.rand(2, 2, 2)
1092
        y = torch.rand(4, 2)
1093
        y.set_(x.storage())
1094
        mode = FakeTensorMode()
1095
        converter = mode.fake_tensor_converter
1096
        x_conv = converter.from_real_tensor(mode, x)
1097
        y_conv = converter.from_real_tensor(mode, y)
1098
        stor_id = torch._C._storage_id(x_conv)
1099
        self.assertEqual(stor_id, torch._C._storage_id(y_conv))
1100
        del x
1101
        del x_conv
1102
        self.assertEqual(len(converter.tensor_memo), 1)
1103
        self.assertEqual(len(converter.meta_converter.storage_memo), 1)
1104
        del y
1105
        del y_conv
1106
        self.assertEqual(len(converter.tensor_memo), 0)
1107
        self.assertEqual(len(converter.meta_converter.storage_memo), 0)
1108

1109
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1110
    def test_dead_weak_ref(self):
1111
        x = torch.rand(2, 2, 2)
1112
        y = x[0]
1113
        mode = FakeTensorMode()
1114
        converter = FakeTensorConverter()
1115
        x_conv = converter.from_real_tensor(mode, x)
1116
        x_conv_storage = x_conv.untyped_storage()
1117
        del x_conv
1118
        self.assertFalse(x in converter.tensor_memo)
1119
        y_conv = converter.from_real_tensor(mode, y)
1120
        self.assertIs(x_conv_storage, y_conv.untyped_storage())
1121

1122
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1123
    def test_dead_key(self):
1124
        x = torch.rand(2, 2, 2)
1125
        mode = FakeTensorMode()
1126
        converter = FakeTensorConverter()
1127
        x_conv = converter.from_real_tensor(mode, x)
1128
        self.assertEqual(len(converter.tensor_memo), 1)
1129
        x_conv2 = converter.from_real_tensor(mode, x)
1130
        assert x_conv2 is x_conv
1131
        del x
1132
        del x_conv
1133
        del x_conv2
1134
        self.assertEqual(len(converter.tensor_memo), 0)
1135

1136
    def test_no_active_mode(self):
1137
        with FakeTensorMode() as mode:
1138
            x = torch.empty(2, 2, device="cpu")
1139
            y = torch.empty(2, 2, device="cpu")
1140

1141
        out = x + y
1142
        self.assertEqual(mode, out.fake_mode)
1143
        self.assertTrue(isinstance(out, FakeTensor))
1144
        self.assertEqual(out.device.type, "cpu")
1145

1146
    def test_multiple_modes(self):
1147
        t = torch.rand([4])
1148
        t2 = torch.rand([4])
1149
        with FakeTensorMode() as m:
1150
            with FakeTensorMode() as m2:
1151
                t_fake = m.from_tensor(t)
1152
                t2_fake = m2.from_tensor(t2)
1153

1154
                with self.assertRaisesRegex(Exception, "Mixing fake modes"):
1155
                    t_fake + t2_fake
1156

1157
    def test_separate_mode_error(self):
1158
        with FakeTensorMode():
1159
            x = torch.empty(2, 2, device="cpu")
1160
        with FakeTensorMode():
1161
            y = torch.empty(2, 2, device="cpu")
1162
        self.assertRaises(Exception, lambda: x, y)
1163

1164
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1165
    def test_no_ref_cycle(self):
1166
        x = torch.rand([4])
1167
        mode = FakeTensorMode()
1168
        y = mode.from_tensor(x)
1169
        self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
1170
        mode_weak = weakref.ref(mode)
1171
        y_weak = weakref.ref(mode)
1172
        del mode
1173
        del y
1174
        assert mode_weak() is None
1175
        assert y_weak() is None
1176

1177

1178
make_propagate_real_tensors_cls(FakeTensorConverterTest)
1179

1180

1181
class FakeTensorOperatorInvariants(TestCase):
1182
    def get_aten_op(self, schema):
1183
        namespace, name = schema.name.split("::")
1184
        overload = schema.overload_name if schema.overload_name else "default"
1185
        assert namespace == "aten"
1186
        return getattr(getattr(torch.ops.aten, name), overload)
1187

1188
    def get_all_aten_schemas(self):
1189
        for schema in torch._C._jit_get_all_schemas():
1190
            namespace = schema.name.split("::")[0]
1191
            if namespace != "aten":
1192
                continue
1193
            yield schema
1194

1195
    def test_non_kwarg_only_device(self):
1196
        for schema in self.get_all_aten_schemas():
1197
            ten_type = torch._C.TensorType.get()
1198
            if not any(
1199
                contains_type(arg.type, ten_type)
1200
                for arg in itertools.chain(schema.arguments, schema.returns)
1201
            ):
1202
                continue
1203

1204
            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
1205
            has_non_kwarg_device = any(
1206
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
1207
                for arg in schema.arguments
1208
            )
1209
            if has_non_kwarg_device:
1210
                self.assertTrue(
1211
                    self.get_aten_op(schema)
1212
                    in torch._subclasses.fake_tensor._device_not_kwarg_ops
1213
                )
1214

1215
    def test_tensor_constructors_all_have_kwarg_device(self):
1216
        for schema in self.get_all_aten_schemas():
1217
            op = self.get_aten_op(schema)
1218
            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
1219
                continue
1220

1221
            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
1222
            has_kwarg_device = any(
1223
                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
1224
                for arg in schema.arguments
1225
            )
1226

1227
            self.assertTrue(
1228
                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
1229
            )
1230

1231
    @unittest.expectedFailure
1232
    def test_sparse_new(self):
1233
        with FakeTensorMode():
1234
            indices = torch.randn(1, 1, dtype=torch.int64)
1235
            values = torch.randn(1)
1236
            extra = (2,)
1237
            sparse = torch.randn(1).to_sparse()
1238
            # This used to segfault, now it does not, but it still raises an
1239
            # error
1240
            sparse2 = sparse.new(indices, values, extra)
1241

1242
    def test_tensor_new(self):
1243
        with FakeTensorMode():
1244
            x = torch.Tensor([1, 2, 3])
1245
        self.assertIsInstance(x, FakeTensor)
1246

1247
    def test_like_ops(self):
1248
        for schema in self.get_all_aten_schemas():
1249
            if "_like" == schema.name[-5:]:
1250
                op = self.get_aten_op(schema)
1251
                self.assertIn(
1252
                    op, torch._subclasses.fake_tensor._like_tensor_constructors
1253
                )
1254

1255
    def test_str_storage(self):
1256
        x = torch.zeros(3)
1257
        with FakeTensorMode() as m:
1258
            y = m.from_tensor(x)
1259
            self.assertExpectedInline(
1260
                str(x.storage()),
1261
                """\
1262
 0.0
1263
 0.0
1264
 0.0
1265
[torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]""",
1266
            )
1267
            self.assertExpectedInline(
1268
                str(y.storage()),
1269
                """\
1270
...
1271
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""",
1272
            )
1273

1274
        self.assertExpectedInline(
1275
            str(y.storage()),
1276
            """\
1277
...
1278
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""",
1279
        )
1280

1281
    # at::_embedding_bag has no op info,
1282
    # and returns extra tensors that at::embedding bag throws away
1283
    def test_embedding_bag_private(self):
1284
        args = [
1285
            torch.ones(6, 1),
1286
            torch.ones(6, dtype=torch.int64),
1287
            torch.arange(2, dtype=torch.int64),
1288
            False,
1289
            2,  # mode = max
1290
        ]
1291

1292
        ref_out = torch.ops.aten._embedding_bag(*args)
1293
        with FakeTensorMode() as m:
1294
            meta_args = [
1295
                m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
1296
            ]
1297
            meta_out = torch.ops.aten._embedding_bag(*meta_args)
1298

1299
        self.assertEqual(len(ref_out), len(meta_out))
1300
        for ref_o, meta_o in zip(ref_out, meta_out):
1301
            self.assertEqual(ref_o.size(), meta_o.size())
1302

1303
    def test_cross_entropy_loss(self):
1304
        inp = torch.randn(3, 5)
1305
        target = torch.randint(5, (3,), dtype=torch.long)
1306
        weight = torch.rand(5)
1307
        fn = torch.nn.functional.cross_entropy
1308
        for w in (weight, None):
1309
            args = (inp, target, w)
1310
            ref = fn(*args)
1311
            with FakeTensorMode() as m:
1312
                meta_args = [
1313
                    m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
1314
                ]
1315
                meta_out = torch.nn.functional.cross_entropy(
1316
                    *meta_args, label_smoothing=0.5
1317
                )
1318

1319
            self.assertEqual(ref.size(), meta_out.size())
1320

1321
    @skipIfRocm
1322
    @unittest.skipIf(
1323
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
1324
        "Does not support SDPA or pre-SM80 hardware",
1325
    )
1326
    def test_flash_attention(self):
1327
        class Repro(torch.nn.Module):
1328
            def __init__(self) -> None:
1329
                super().__init__()
1330

1331
            def forward(self, arg1, arg2, arg3):
1332
                torch.ops.aten._scaled_dot_product_flash_attention(
1333
                    arg1, arg2, arg3, scale=0.17677669529663687
1334
                )
1335

1336
        args_new = [
1337
            [
1338
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1339
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1340
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1341
            ],
1342
            [
1343
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1344
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1345
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1346
            ],
1347
        ]
1348
        for args_list in args_new:
1349
            args = [
1350
                rand_strided(bsz, num_heads, seq_len, head_dim)
1351
                for (bsz, num_heads, seq_len, head_dim) in args_list
1352
            ]
1353
            try:
1354
                with torch._subclasses.CrossRefFakeMode():
1355
                    Repro()(*args)
1356
            except RuntimeError as e:
1357
                # We expect the cross ref to succed for the first output to fail
1358
                # for the rng state, see Note [Seed and Offset]
1359
                self.assertTrue("output[0]" not in str(e))
1360
                self.assertTrue(
1361
                    "found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!"
1362
                    in str(e)
1363
                )
1364

1365
    # IMPORTANT!!! Always run even if CUDA is not available
1366
    def test_fake_gpu_no_init(self):
1367
        # Skip this test, we will try to run CUDA operations to real prop so
1368
        # it clearly will not work on CPU runner
1369
        if torch._functorch.config.fake_tensor_propagate_real_tensors:
1370
            return
1371
        with FakeTensorMode():
1372
            torch.empty(10, device=GPU_TYPE)
1373
            torch.ones(10, device=GPU_TYPE)
1374
            torch.zeros(10, device=GPU_TYPE)
1375
            torch.rand(10, device=GPU_TYPE)
1376
            torch.tensor(3.14, device=GPU_TYPE)
1377
            torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE)
1378

1379
    @skipIfRocm
1380
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
1381
    def test_conv_c1_backward(self):
1382
        class Repro(torch.nn.Module):
1383
            def __init__(self) -> None:
1384
                super().__init__()
1385

1386
            def forward(self, arg1, arg2, arg3):
1387
                torch.ops.aten.convolution_backward.default(
1388
                    arg1,
1389
                    arg2,
1390
                    arg3,
1391
                    [1],
1392
                    [1, 1],
1393
                    [1, 1],
1394
                    [1, 1],
1395
                    False,
1396
                    [0, 0],
1397
                    1,
1398
                    [True, True, False],
1399
                )
1400

1401
        args_new = [
1402
            ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"),
1403
            ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"),
1404
            ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"),
1405
        ]
1406
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]
1407

1408
        with torch._subclasses.CrossRefFakeMode():
1409
            Repro()(*args)
1410

1411
    def test_no_dispatch_with_like_function(self):
1412
        class CountingMode(TorchDispatchMode):
1413
            def __init__(self) -> None:
1414
                self.count = 0
1415

1416
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1417
                self.count += 1
1418
                return func(*args, **kwargs)
1419

1420
        with FakeTensorMode():
1421
            x = torch.randn(2)
1422
            with CountingMode() as mode:
1423
                with no_dispatch():
1424
                    torch.zeros_like(x)
1425

1426
        self.assertEqual(mode.count, 0)
1427

1428

1429
make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)
1430

1431

1432
class FakeTensorPropTest(TestCase):
1433
    def test_fake_tensor_prop_on_nn_module(self):
1434
        class ToyNnModuleWithParameters(torch.nn.Module):
1435
            def __init__(self) -> None:
1436
                super().__init__()
1437
                self.layer1 = torch.nn.Linear(4, 3)
1438
                self.layer2 = torch.nn.Linear(3, 2)
1439

1440
            def forward(self, value):
1441
                value = self.layer1(value)
1442
                value = torch.relu(value)
1443
                value = self.layer2(value)
1444
                return value
1445

1446
        model = ToyNnModuleWithParameters()
1447
        value = torch.randn(5, 4)
1448
        # Convert nn.Module to GraphModule so that FakeTensorProp runs.
1449
        graph_model = torch.fx.symbolic_trace(model, (value,))
1450
        # The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode
1451
        #
1452
        # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule
1453
        # with parameters and buffers.
1454
        with FakeTensorMode() as fake_tensor_mode:
1455

1456
            def to_fake_tensor(x):
1457
                if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
1458
                    return fake_tensor_mode.from_tensor(x)
1459
                return x
1460

1461
            fake_parameters_and_buffers = {
1462
                k: to_fake_tensor(v)
1463
                for k, v in itertools.chain(
1464
                    graph_model.named_parameters(), graph_model.named_buffers()
1465
                )
1466
            }
1467
            with torch.nn.utils.stateless._reparametrize_module(
1468
                graph_model, fake_parameters_and_buffers
1469
            ):
1470
                # This case uses the **same** fake tensor mode to
1471
                #  1. create fake parameters and fake buffers, and
1472
                #  2. run FakeTensorProp
1473
                # The result should be correct.
1474
                result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
1475
                self.assertTrue(isinstance(result, FakeTensor))
1476
                self.assertEqual(result.shape, (5, 2))
1477
                # This case uses the **different** fake tensor modes to
1478
                #  1. create fake parameters and fake buffers, and
1479
                #  2. run FakeTensorProp
1480
                # The following code should fail.
1481
                failed = False
1482
                try:
1483
                    FakeTensorProp(graph_model).propagate(value)
1484
                except AssertionError:
1485
                    # AssertionError: tensor's device must be `meta`, got cpu instead
1486
                    failed = True
1487
                self.assertTrue(failed)
1488

1489
    def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
1490
        class OptionalArgumentInBetween(torch.nn.Module):
1491
            def __init__(self) -> None:
1492
                super().__init__()
1493
                self.layer1 = torch.nn.Linear(4, 3)
1494
                self.layer2 = torch.nn.Linear(3, 2)
1495

1496
            def forward(self, value, another_value=None, another_optional_value=None):
1497
                # Mimic huggingface's `forward` methods which have several optional arguments.
1498
                # For example, GPT accepts forward(self, input_ids, None, attention_mask, ...).
1499
                # To apply FakeTensorProp, its from_real_tensor(...) needs to accept None.
1500
                if another_value is None:
1501
                    another_value = torch.rand_like(value)
1502
                if another_optional_value is None:
1503
                    another_optional_value = torch.rand_like(value)
1504
                value = value + another_value + another_optional_value
1505
                return value * value
1506

1507
        fake_mode = FakeTensorMode(
1508
            allow_non_fake_inputs=True, allow_fallback_kernels=False
1509
        )
1510
        with fake_mode:
1511
            model = OptionalArgumentInBetween()
1512
            value = torch.randn(5, 4)
1513
            another_optional_value = torch.randn(5, 4)
1514
            graph_model = torch.fx.symbolic_trace(
1515
                model, (value, None, another_optional_value)
1516
            )
1517
            FakeTensorProp(graph_model, fake_mode).propagate(
1518
                value, None, another_optional_value
1519
            )
1520

1521
    def test_unbacked_shape_realloc(self):
1522
        def f(x):
1523
            return x.nonzero()
1524

1525
        shape_env = ShapeEnv()
1526
        fake_mode = FakeTensorMode(shape_env=shape_env)
1527
        with fake_mode:
1528
            value = torch.randn(5)
1529
            gm = make_fx(f)(value)
1530
        nonzero_nodes = [
1531
            n for n in gm.graph.nodes if n.target is torch.ops.aten.nonzero.default
1532
        ]
1533
        self.assertEqual(len(nonzero_nodes), 1)
1534
        self.assertIsInstance(nonzero_nodes[0].meta["val"].shape[0], torch.SymInt)
1535
        u0 = nonzero_nodes[0].meta["val"].shape[0]
1536
        FakeTensorProp(gm, fake_mode).propagate(value)
1537
        u1 = nonzero_nodes[0].meta["val"].shape[0]
1538
        # Test that this test is actually doing something in that the
1539
        # FakeTensorProp actually triggered a reallocation.  If this assert is
1540
        # failing, it could be because we started memoizing the nnz count for
1541
        # nonzero, which is nice in some sense (no reallocation) but not
1542
        # helpful for this test, which is checking what we do when we have
1543
        # to reallocate.  If so, you need to make this example more
1544
        # complicated (e.g., maybe have a nontrivial computation on the input
1545
        # before feeding it into nonzero, or have some sort of randomness)
1546
        self.assertIsNot(u0, u1)
1547
        self.assertTrue(statically_known_true(u0 == u1))
1548

1549
    def test_torch_load_with_fake_mode(self):
1550
        class TheModelClass(torch.nn.Module):
1551
            def __init__(self) -> None:
1552
                super().__init__()
1553
                self.fc1 = torch.nn.Linear(5, 10)
1554

1555
            def forward(self, x):
1556
                return self.fc1(x)
1557

1558
        with TemporaryFileName() as state_dict_file:
1559
            # Create state_dict to be loaded later
1560
            model = TheModelClass()
1561
            torch.save(model.state_dict(), state_dict_file)
1562

1563
            fake_mode = FakeTensorMode()
1564
            with fake_mode:
1565
                torch.load(state_dict_file)  # scenario 1
1566
                torch.load(state_dict_file, map_location="cpu")  # scenario 2
1567

1568

1569
make_propagate_real_tensors_cls(FakeTensorPropTest)
1570

1571

1572
class FakeTensorSerialization(TestCase):
1573
    def test_serialization(self):
1574
        x = torch.tensor([0], device="cpu")
1575
        with FakeTensorMode():
1576
            y = pickle.loads(pickle.dumps(x))
1577
            self.assertEqual(type(y), FakeTensor)
1578
            self.assertEqual(y.device.type, "meta")
1579

1580
            with unset_fake_temporarily():
1581
                y = pickle.loads(pickle.dumps(x))
1582
                self.assertEqual(x.device, y.device)
1583

1584
    def test_serialization_with_tracing(self):
1585
        x = torch.tensor([0], device="cpu")
1586
        with tracing(TracingContext(FakeTensorMode())):
1587
            y = pickle.loads(pickle.dumps(x))
1588
            self.assertEqual(x.device, y.device)
1589

1590

1591
class FakeTensorDispatchCache(TestCase):
1592
    def test_shape_env_settings(self):
1593
        """
1594
        Validation that any boolean settings in ShapeEnv are present in the
1595
        ShapeEnvSettings. We hope to ensure that any new settings that might
1596
        affect FakeTensor dispatch are included in the cache key calculation.
1597
        If this test fails, consider updating ShapeEnvSettings or change this
1598
        test to omit checking for the new field.
1599
        """
1600
        init_sig = inspect.signature(ShapeEnv._init)
1601
        args = [
1602
            name
1603
            for name, param in init_sig.parameters.items()
1604
            if type(param.default) is bool
1605
        ]
1606

1607
        settings = [f.name for f in dataclasses.fields(ShapeEnvSettings)]
1608
        for arg in args:
1609
            self.assertTrue(arg in settings)
1610

1611
    def _test_cache_key(self, fm, x, y, z):
1612
        """
1613
        Helper for all test_cache_key_* tests below. Assert that the
1614
        cache keys for inputs x and y are the same, but z is different.
1615
        """
1616
        func = aten.add.Tensor
1617
        state = _CacheKeyState()
1618
        key_x = fm._cache_key(state, func, [x], {})
1619
        key_y = fm._cache_key(state, func, [y], {})
1620
        key_z = fm._cache_key(state, func, [z], {})
1621

1622
        self.assertEqual(key_x, key_y)
1623
        self.assertNotEqual(key_x, key_z)
1624

1625
    def test_cache_key_dtype(self):
1626
        with FakeTensorMode() as fm:
1627
            x = torch.randn(4, 3, dtype=torch.float16)
1628
            y = torch.randn(4, 3, dtype=torch.float16)
1629
            z = x.to(dtype=torch.float32)
1630
            self._test_cache_key(fm, x, y, z)
1631

1632
    def test_cache_key_shape(self):
1633
        with FakeTensorMode() as fm:
1634
            x = torch.randn(4, 3)
1635
            y = torch.randn(4, 3)
1636
            z = torch.randn(4, 2)
1637
            self._test_cache_key(fm, x, y, z)
1638

1639
    def test_cache_key_stride(self):
1640
        with FakeTensorMode() as fm:
1641
            x = torch.randn(4, 2)
1642
            y = torch.randn(4, 2)
1643
            z = x.as_strided((4, 2), (1, 2))
1644
            self._test_cache_key(fm, x, y, z)
1645

1646
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
1647
    def test_cache_key_device(self):
1648
        with FakeTensorMode() as fm:
1649
            x = torch.randn(4, 3)
1650
            y = torch.randn(4, 3)
1651
            z = x.to(device="cuda")
1652
            self._test_cache_key(fm, x, y, z)
1653

1654
    def test_cache_key_memory_format(self):
1655
        with FakeTensorMode() as fm:
1656
            x = torch.randn(1, 2, 3, 4)
1657
            y = torch.randn(1, 2, 3, 4)
1658
            z = x.to(memory_format=torch.channels_last)
1659
            self._test_cache_key(fm, x, y, z)
1660

1661
    def test_cache_key_storage_offset(self):
1662
        with FakeTensorMode() as fm:
1663
            x = torch.randn(3)[1:]
1664
            y = torch.randn(3)[1:]
1665
            z = torch.randn(2)
1666
            self._test_cache_key(fm, x, y, z)
1667

1668
    def test_cache_key_requires_grad(self):
1669
        with FakeTensorMode() as fm:
1670
            x = torch.randn(4, 3)
1671
            y = torch.randn(4, 3)
1672
            z = torch.randn(4, 3, requires_grad=True)
1673
            self._test_cache_key(fm, x, y, z)
1674

1675
    def test_cache_key_is_conj(self):
1676
        with FakeTensorMode() as fm:
1677
            x = torch.randn(4, 3, dtype=torch.complex64)
1678
            y = torch.randn(4, 3, dtype=torch.complex64)
1679
            z = torch.randn(4, 3, dtype=torch.complex64)
1680
            torch._C._set_conj(z, not z.is_conj())
1681
            self._test_cache_key(fm, x, y, z)
1682

1683
    def test_cache_key_is_neg(self):
1684
        with FakeTensorMode() as fm:
1685
            x = torch.randn(4, 3, dtype=torch.complex64)
1686
            y = torch.randn(4, 3, dtype=torch.complex64)
1687
            z = torch.randn(4, 3, dtype=torch.complex64)
1688
            torch._C._set_neg(z, not z.is_neg())
1689
            self._test_cache_key(fm, x, y, z)
1690

1691
    def test_cache_key_is_inference(self):
1692
        with torch.inference_mode(True):
1693
            t = torch.randn(4, 3)
1694
        with FakeTensorMode() as fm:
1695
            x = torch.randn(4, 3)
1696
            y = torch.randn(4, 3)
1697
            z = fm.from_tensor(t)
1698
            self._test_cache_key(fm, x, y, z)
1699

1700
    def test_cache_key_constants(self):
1701
        with FakeTensorMode() as fm:
1702
            # Python hashes 1.0 to the same value as 1. Make sure the
1703
            # cache key calculation differentiates them.
1704
            self._test_cache_key(fm, 1.0, 1.0, 1)
1705
            self._test_cache_key(fm, 0.0, 0.0, 0)
1706

1707
    def assertHitsMisses(self, hits, misses):
1708
        """
1709
        Helper to assert on the number of recorded hits and misses.
1710
        """
1711
        info = FakeTensorMode.cache_info()
1712
        self.assertEqual(info.hits, hits)
1713
        self.assertEqual(info.misses, misses)
1714

1715
    def assertBypasses(self, reason, count):
1716
        """
1717
        Helper to assert on the number of recorded bypasses.
1718
        """
1719
        info = FakeTensorMode.cache_info()
1720
        if count > 0:
1721
            self.assertIn(reason, info.bypasses)
1722
            self.assertEqual(info.bypasses[reason], count)
1723
        else:
1724
            self.assertNotIn(reason, info.bypasses)
1725

1726
    def test_cache_hit(self):
1727
        """
1728
        Test that cache hit/miss counters are updated correctly.
1729
        """
1730
        with FakeTensorMode():
1731
            x = torch.randn(4, 3)
1732
            y = torch.randn(4, 3)
1733

1734
            FakeTensorMode.cache_clear()
1735
            self.assertHitsMisses(0, 0)
1736
            res1 = x + y
1737
            self.assertHitsMisses(0, 1)
1738
            res2 = x + y
1739
            self.assertHitsMisses(1, 1)
1740

1741
            self.assertEqual(
1742
                extract_tensor_metadata(res1),
1743
                extract_tensor_metadata(res2),
1744
            )
1745

1746
    def test_cache_bypass(self):
1747
        """
1748
        Test that cache bypass counters are updated correctly.
1749
        """
1750
        with FakeTensorMode():
1751
            x = torch.randn(1, 2)
1752

1753
            FakeTensorMode.cache_clear()
1754
            self.assertBypasses("inplace view", 0)
1755

1756
            x.unsqueeze_(0)
1757
            self.assertBypasses("inplace view", 1)
1758

1759
    def test_cache_default_dtype(self):
1760
        """
1761
        Test that the default dtype is respected when serving cached results.
1762
        """
1763
        with FakeTensorMode():
1764
            x = torch.tensor([1, 2], dtype=torch.int32)
1765
            torch.set_default_dtype(torch.float32)
1766

1767
            FakeTensorMode.cache_clear()
1768
            self.assertHitsMisses(0, 0)
1769

1770
            y = x + 1.0
1771
            self.assertEqual(y.dtype, torch.float32)
1772
            self.assertHitsMisses(0, 1)
1773

1774
            torch.set_default_dtype(torch.float16)
1775
            y = x + 1.0
1776
            self.assertEqual(y.dtype, torch.float16)
1777
            self.assertHitsMisses(0, 2)
1778

1779
            torch.set_default_dtype(torch.float32)
1780
            y = x + 1.0
1781
            self.assertEqual(y.dtype, torch.float32)
1782
            self.assertHitsMisses(1, 2)
1783

1784
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
1785
    def test_cache_default_device(self):
1786
        """
1787
        Test that the default device is respected when serving cached results.
1788
        """
1789
        with FakeTensorMode():
1790
            FakeTensorMode.cache_clear()
1791
            self.assertHitsMisses(0, 0)
1792

1793
            torch.set_default_device("cpu")
1794
            x = torch.tensor([1, 2])
1795
            y = x + 1.0
1796
            self.assertEqual(y.device.type, "cpu")
1797
            self.assertHitsMisses(0, 1)
1798

1799
            torch.set_default_device("cuda")
1800
            x = torch.tensor([1, 2])
1801
            y = x + 1.0
1802
            self.assertEqual(y.device.type, "cuda")
1803
            self.assertHitsMisses(0, 2)
1804

1805
            torch.set_default_device("cpu")
1806
            x = torch.tensor([1, 2])
1807
            y = x + 1.0
1808
            self.assertEqual(y.device.type, "cpu")
1809
            self.assertHitsMisses(1, 2)
1810

1811
    def test_cache_inplace_op(self):
1812
        """
1813
        Test that inplace ops served from the cache correctly reference the
1814
        input parameter.
1815
        """
1816
        with FakeTensorMode():
1817
            x = torch.randn(1, 2)
1818
            y = torch.randn(1, 2)
1819

1820
            FakeTensorMode.cache_clear()
1821
            self.assertHitsMisses(0, 0)
1822

1823
            z = x.add_(y)
1824
            self.assertHitsMisses(0, 1)
1825
            self.assertEqual(id(x), id(z))
1826

1827
            w = x.add_(y)
1828
            self.assertHitsMisses(1, 1)
1829
            self.assertEqual(id(x), id(w))
1830

1831
    def test_cache_view_op(self):
1832
        """
1833
        Test that view ops are handled correctly when served from the cache.
1834
        """
1835
        with FakeTensorMode():
1836
            x1 = torch.ones(2, requires_grad=True).clone()
1837
            x2 = torch.ones(2, requires_grad=True).clone()
1838
            y2 = x2.view(-1)
1839

1840
            # Test operating on a non-view tensor, then the same operation
1841
            # on a view tensor. Assert that the view property is set correctly.
1842
            z1 = x1.mul_(2)
1843
            self.assertFalse(z1._is_view())
1844

1845
            z2 = y2.mul_(2)
1846
            self.assertTrue(z2._is_view())
1847

1848
            # Now the other way around: first operate on a view tensor, then
1849
            # the same operation on a non-view tensor.
1850
            z2 = y2.mul_(2)
1851
            self.assertTrue(z2._is_view())
1852

1853
            z1 = x1.mul_(2)
1854
            self.assertFalse(z1._is_view())
1855

1856
    def test_cache_dispatch_key_set(self):
1857
        """
1858
        Test that operations that change the dispatch key set bypass caching.
1859
        """
1860
        with FakeTensorMode():
1861
            FakeTensorMode.cache_clear()
1862
            self.assertBypasses("dispatch_key_set mismatch", 0)
1863

1864
            x = torch._efficientzerotensor(3)
1865
            self.assertTrue(x._is_zerotensor())
1866
            self.assertBypasses("dispatch_key_set mismatch", 1)
1867

1868
            y = torch._efficientzerotensor(3)
1869
            self.assertTrue(y._is_zerotensor())
1870
            self.assertBypasses("dispatch_key_set mismatch", 2)
1871

1872
    def test_inference_mode(self):
1873
        """
1874
        Test that caching handles inference mode correctly.
1875
        """
1876
        with FakeTensorMode():
1877
            x = torch.randn(4, 3)
1878
            y = torch.randn(4, 3)
1879

1880
            FakeTensorMode.cache_clear()
1881
            self.assertHitsMisses(0, 0)
1882

1883
            # Expect a miss when the inference mode is different
1884
            res1 = x + y
1885
            with torch.inference_mode():
1886
                res2 = x + y
1887

1888
            self.assertHitsMisses(0, 2)
1889
            self.assertFalse(res1.is_inference())
1890
            self.assertTrue(res2.is_inference())
1891

1892
            # Second tries should see hits
1893
            res3 = x + y
1894

1895
            self.assertHitsMisses(1, 2)
1896
            self.assertFalse(res3.is_inference())
1897
            self.assertEqual(
1898
                extract_tensor_metadata(res1),
1899
                extract_tensor_metadata(res3),
1900
            )
1901

1902
            with torch.inference_mode():
1903
                res4 = x + y
1904

1905
            self.assertHitsMisses(2, 2)
1906
            self.assertTrue(res4.is_inference())
1907
            self.assertEqual(
1908
                extract_tensor_metadata(res2),
1909
                extract_tensor_metadata(res4),
1910
            )
1911

1912

1913
if __name__ == "__main__":
1914
    run_tests()
1915

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.