# Owner(s): ["module: meta tensors"]

import sys

from torch.testing._internal.common_utils import (
    TestCase, TEST_WITH_TORCHDYNAMO, run_tests, skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, parametrize,
    instantiate_parametrized_tests, TemporaryFileName)
import torch
import torch._dynamo
import itertools
import numpy as np
from torch.testing._internal.jit_utils import RUN_CUDA
from torch._guards import tracing, TracingContext
from torch._subclasses.fake_tensor import (
    _ShapeEnvSettings,
    extract_tensor_metadata,
    FakeTensor,
    FakeTensorMode,
    FakeTensorConverter,
    DynamicOutputShapeException,
    UnsupportedOperatorException,
    unset_fake_temporarily,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols, StatelessSymbolicContext
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._dynamo.testing import rand_strided
from torch._C._functorch import is_batchedtensor, _add_batch_dim, get_unwrapped
from torch.testing import FileCheck
import dataclasses
import inspect
import unittest
import torch._prims as prims
import contextlib
import weakref
import copy
import pickle
import torch._functorch.config
import torch.testing._internal.optests as optests
from unittest.mock import patch

from torch import distributed as dist
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
import torch.utils._pytree as pytree

aten = torch.ops.aten

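# Exercise the FakeTensor dispatch cache throughout these tests, and cross-check
# every cache hit against an uncached computation (see the
# FakeTensorDispatchCache tests at the bottom of this file).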
torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True

class FakeTensorTest(TestCase):
    def checkType(self, t, device_str, size):
        self.assertTrue(isinstance(t, FakeTensor))
        self.assertEqual(t.device.type, device_str)
        self.assertEqual(list(t.size()), size)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cuda_initialized(self):
        # doesn't error
        with FakeTensorMode():
            p = torch.randn(4, 2, requires_grad=True, device='cuda')
            x = torch.randn(8, 4, device='cuda')
            y = torch.mm(x, p).square().sum()
            y.backward()

    def test_basic(self):
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            x = mode.from_tensor(x)
            y = mode.from_tensor(y)
            z = x + y
            self.assertEqual(z.shape, (4, 2, 2))
            self.assertEqual(z.device, torch.device("cpu"))
            self.assertTrue(isinstance(z, FakeTensor))

    def test_basic_forced_memo_only(self):
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            x_fake = mode.from_tensor(x)
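            # memoized_only=True should return the fake previously created for
            # this tensor, or None if it was never converted (an inference from
            # the assertions below).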
            x2 = mode.from_tensor(x, memoized_only=True)
            self.assertTrue(x2 is not None)
            y = mode.from_tensor(y, memoized_only=True)
            self.assertIs(y, None)

    def test_custom_op_fallback(self):
        from torch.library import Library, impl

        test_lib = Library("my_test_op", "DEF")  # noqa: TOR901
        test_lib.define('foo(Tensor self) -> Tensor')

        @impl(test_lib, 'foo', 'CPU')
        def foo_impl(self):
            return self.cos()

        x = torch.empty(2, 2, device="cpu")
        with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
            with FakeTensorMode(allow_fallback_kernels=True) as mode:
                x = mode.from_tensor(x)
                torch.ops.my_test_op.foo(x)

    def test_parameter_instantiation(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.nn.parameter.Parameter(x)
            self.assertTrue(isinstance(y, torch.nn.Parameter))

    @unittest.skipIf(not dist.is_available(), "requires distributed")
    def test_fsdp_flat_param(self):
        from torch.distributed.fsdp._flat_param import FlatParameter
        with FakeTensorMode() as m:
            data = torch.randn(2, 2)
            param = FlatParameter(data, requires_grad=True)
        self.assertIsInstance(param, FlatParameter)
        self.assertIsInstance(param, torch.nn.Parameter)
        self.assertIsInstance(param, FakeTensor)

    def test_non_parameter_grad(self):
        mode = FakeTensorMode()
        t = torch.rand([4], requires_grad=True)
        fake_t = mode.from_tensor(t)
        self.assertEqual(fake_t.requires_grad, t.requires_grad)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_index_cuda_with_cpu(self):
        with FakeTensorMode():
            x = torch.rand([2048], device='cuda')
            out = x[torch.zeros([36], dtype=torch.int64)]
            self.checkType(out, "cuda", [36])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_shape_take_not_device(self):
        with FakeTensorMode():
            x = torch.empty(1, device="cpu")
            y = torch.empty(8, 8, device="cuda")
            out = x.resize_as_(y)
            self.assertEqual(out.shape, (8, 8))
            self.assertEqual(out.device.type, "cpu")
            self.assertTrue(isinstance(out, FakeTensor))

    def test_repr(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
            self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
            x = torch.empty(2, 2, device="meta")
            self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_zero_dim(self):
        with FakeTensorMode() as mode:
            x = torch.tensor(0.)
            y = torch.rand([4, 4], device="cuda")
            out = x + y
            self.assertEqual(out.shape, (4, 4))
            self.assertEqual(out.device, y.device)
            self.assertTrue(isinstance(out, FakeTensor))

    def test_nan_to_num(self):
        with FakeTensorMode():
            for dtype in [torch.float16, torch.float32]:
                x = torch.rand([4], dtype=dtype)
                y = torch.nan_to_num(x, nan=None)
                z = torch.nan_to_num(x, 0.0)
                self.assertEqual(dtype, y.dtype)
                self.assertEqual(dtype, z.dtype)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_throw(self):
        x = torch.tensor(0.)  # TODO: tensor() errors
        with FakeTensorMode() as mode:
            x_conv = mode.from_tensor(x)
            y = torch.rand([4, 4], device="cuda")
            z = torch.rand([4, 4], device="cpu")
            self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_type_as(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = torch.rand([4, 4], device="cuda")
            out = x.type_as(y)
            self.assertEqual(out.device.type, "cuda")
            self.assertTrue(isinstance(out, FakeTensor))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_setitem(self):
        for device in ["cpu", "cuda"]:
            with FakeTensorMode():
                x = torch.rand([16, 1], device=device)
                x[..., 0] = 0

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_device_inplace_copy(self):
        with FakeTensorMode():
            x = torch.rand([8, 8], device="cpu")
            y = torch.rand([8, 8], device="cuda")
            assert x.copy_(y).device.type == "cpu"
            assert y.copy_(x).device.type == "cuda"

    def test_fake_dispatch_keys(self):
        with FakeTensorMode():
            x = torch.rand([4])
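            # A fake CPU tensor should carry the usual CPU dispatch keys,
            # including autograd and autocast; inference mode then drops the
            # ADInplaceOrView and Autograd keys, as the checks below verify.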
            f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
            f.run(torch._C._dispatch_key_set(x))

            with torch.inference_mode():
                x = torch.rand([4])
                y = x + x
                FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

    def test_batch_tensor(self):
        x = torch.rand((3, 4, 5))
        b = _add_batch_dim(x, 0, 0)
        mode = FakeTensorMode()
        fake_b = mode.from_tensor(b)
        prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)

        b1 = _add_batch_dim(x, 1, 1)
        b2 = _add_batch_dim(b1, 0, 2)
        fake_b2 = mode.from_tensor(b2)
        prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
        self.assertTrue(is_batchedtensor(fake_b2))
        fake_b1 = get_unwrapped(fake_b2)
        self.assertTrue(is_batchedtensor(fake_b1))
        fake_tensor = get_unwrapped(fake_b1)
        self.assertIsInstance(fake_tensor, FakeTensor)

    def test_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4], device="cpu")

        self.assertTrue(isinstance(x, FakeTensor))
        self.assertTrue(x.device.type == "cpu")

    def test_mode(self):
        with FakeTensorMode():
            y = torch.rand([4], device="cpu")
            out = y + y

        self.assertTrue(isinstance(out, FakeTensor))

    def test_full(self):
        # Test torch.full returns tensor with correct dtype
        with torch._subclasses.CrossRefFakeMode():
            y = torch.full((4, 4), 1)

    def check_function_with_fake(self, fn):
        out = fn()
        with torch._subclasses.FakeTensorMode():
            out_fake = fn()

        for a, b in zip(pytree.tree_leaves(out), pytree.tree_leaves(out_fake)):
            if not isinstance(a, torch.Tensor):
                self.assertTrue(not isinstance(b, torch.Tensor))
                continue

            prims.utils.compare_tensor_meta(a, b, check_strides=True)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_non_kwarg_device(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = x.to(torch.device("cpu"))
            self.assertIs(x, y)
            z = x.to(torch.device("cuda"))
            self.assertEqual(z.device.type, "cuda")

    def test_non_overlapping_stride_zero(self):
        def foo():
            x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
            return x.half()

        self.check_function_with_fake(foo)

    def test_fake_mode_error(self):
        x = torch.rand([4, 4])

        with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
            with FakeTensorMode():
                y = x[0]

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_fake_grad_copy(self):
        x = torch.rand([4, 4], requires_grad=True)
        x.grad = torch.rand([4, 4])
        mode = FakeTensorMode()
        fake_x = mode.from_tensor(x)
        prims.utils.compare_tensor_meta(fake_x, x)
        prims.utils.compare_tensor_meta(fake_x.grad, x.grad)

        self.assertTrue(isinstance(fake_x.grad, FakeTensor))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_index_put_error(self):
        mode = FakeTensorMode()
        for context in [contextlib.nullcontext, lambda: mode]:
            with context():
                y = torch.randn(2, 2, 3)
                x = torch.randn(2, 2, 3).to('cuda')
                with self.assertRaises(RuntimeError):
                    x[[1, 1]] = y

                with self.assertRaises(RuntimeError):
                    torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y)

                # no error
                torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))
                torch.ops.aten.index_put_(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_like_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4])
            y = torch.ones_like(x)
            self.assertTrue(isinstance(y, FakeTensor))
            self.assertEqual(y.device.type, "cpu")
            z = torch.ones_like(x, device="cuda")
            self.assertTrue(isinstance(z, FakeTensor))
            self.assertEqual(z.device.type, "cuda")

    def test_binary_op_type_promotion(self):
        with FakeTensorMode():
            x = torch.empty([2, 2], dtype=torch.float)
            y = torch.empty([2, 2], dtype=torch.int64)
            out = x / y
            self.assertEqual(out.dtype, torch.float)
            self.assertEqual(out.device.type, "cpu")

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_from_numpy(self):
        with FakeTensorMode():
            x = torch.tensor(np.zeros([4, 4]))
            self.checkType(x, "cpu", [4, 4])

    def test_randperm(self):
        x = torch.randperm(10)
        y = torch.randperm(5, device="cpu")
        with FakeTensorMode():
            x1 = torch.randperm(10)
            prims.utils.compare_tensor_meta(x, x1)
            y1 = torch.randperm(5, device="cpu")
            prims.utils.compare_tensor_meta(y, y1)

    def test_print_in_fake_mode(self):
        x = torch.zeros(2)
        # does not fail
        with FakeTensorMode():
            out = str(x)
        assert "FakeTensor" not in out

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_upsample_bilinear_small_channels(self):
        out = []
        mode = FakeTensorMode()
        for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
            with context():
                arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='cuda')
                unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
                out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False))

        self.assertTrue(out[1].is_contiguous())
        self.checkMetaProps(out[0], out[1])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cpu_fallback(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            filters = torch.randn(8, 4, 3, 3).cuda()
            inputs = torch.randn(1, 4, 5, 5).cuda()
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "cuda")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

        with FakeTensorMode(allow_fallback_kernels=True):
            # intentionally bad inputs
            filters = torch.randn(8, 20, 3, 3).cuda()
            inputs = torch.randn(1, 7, 10, 5).cuda()
            with self.assertRaises(RuntimeError):
                torch.nn.functional.conv2d(inputs, filters, padding=1)

        with FakeTensorMode(allow_fallback_kernels=True):
            filters = torch.randn(8, 4, 3, 3).cuda()
            inputs = torch.randn(1, 4, 5, 5).cuda()

            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "cuda")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_out_multi_device(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.rand([4], device="cuda")

            with self.assertRaisesRegex(Exception, "found two different devices"):
                torch.sin(x, out=y)

            with self.assertRaisesRegex(Exception, "found two different devices"):
                x.add_(y)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_normalize_device(self):
        with FakeTensorMode():
            x = torch.empty(1, device="cuda")
            y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
            out = x + y
        self.checkType(out, "cuda", [1])

    def test_recursive_invocation(self):
        mode = FakeTensorMode()
        with mode:
            x = torch.tensor(2)
            mode.in_kernel_invocation = True
            y = x + x
            self.assertTrue(mode.in_kernel_invocation)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @skipIfRocm
    @parametrize("allow_fallback_kernels", [False, True],
                 lambda a: 'with_fallback' if a else 'without_fallback')
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cudnn_rnn(self, allow_fallback_kernels):
        def fn(
            a0,
            b0,
            b1,
            b2,
            b3,
            b4,
            b5,
            b6,
            b7,
            b8,
            b9,
            b10,
            b11,
            b12,
            b13,
            b14,
            b15,
            a3,
            a4,
            a5,
        ):
            a1 = [
                b0,
                b1,
                b2,
                b3,
                b4,
                b5,
                b6,
                b7,
                b8,
                b9,
                b10,
                b11,
                b12,
                b13,
                b14,
                b15,
            ]
            return torch.ops.aten._cudnn_rnn(
                a0,
                a1,
                4,
                a3,
                a4,
                a5,
                2,
                2048,
                0,
                2,
                False,
                0.0,
                False,
                True,
                [],
                None,
            )

        mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
        for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
            with context():
                inps1 = [
                    torch.randn([92, 8, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 4096]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 4096]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([167837696]).cuda(),
                    torch.randn([4, 8, 2048]).cuda(),
                    torch.randn([4, 8, 2048]).cuda(),
                ]
                inps2 = list(inps1)  # copy; `inps2 = inps1` aliased the lists, so both runs lost cx
                inps2[len(inps2) - 1] = None  # argument `cx` can be None

                for inps in [inps1, inps2]:
                    out = fn(*inps)
                    self.assertIs(out[4], inps[-3])
                    for ten in out:
                        if i == 1:
                            self.assertTrue(isinstance(ten, FakeTensor))
                        self.assertEqual(ten.device.type, 'cuda')

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cuda_lstm(self):
        # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
        with torch.backends.cudnn.flags(enabled=False):
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
            with fake_tensor_mode:
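                # Shape names follow the nn.LSTM docs: N = batch size, L =
                # sequence length, H_in = input size, D = number of directions,
                # and H_out is the per-direction output size (proj_size when
                # projections are enabled).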
                N = 5
                L = 4
                H_in = 2
                hidden_size = 3
                proj_size = 2
                num_layers = 2
                bidir = False
                D = 2 if bidir else 1
                H_out = proj_size if proj_size > 0 else hidden_size

                lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
                                     num_layers=num_layers, proj_size=proj_size, batch_first=False,
                                     bias=True, bidirectional=bidir, device='cuda')

                h_0 = torch.randn((num_layers * D, N, H_out), device='cuda')
                c_0 = torch.randn((num_layers * D, N, hidden_size), device='cuda')
                inp = torch.randn((L, N, H_in), device='cuda')
                (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
                output.sum().backward()

                self.assertEqual(output.shape, (L, N, D * H_out))
                self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
                self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))

    def test_data_dependent_operator(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            x = torch.rand([10, 10])

            self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))

    def test_tolist(self):
        shape_env = ShapeEnv()
        with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env):
            x = torch.rand([10])
            x.tolist()

    def test_same_shape_env_preserved(self):
        shape_env = ShapeEnv()
        mode1 = FakeTensorMode(shape_env=shape_env)
        t1 = mode1.from_tensor(
            torch.randn(10),
            symbolic_context=StatelessSymbolicContext(
                dynamic_sizes=[DimDynamic.DYNAMIC],
                constraint_sizes=[None]
            )
        )
        mode2 = FakeTensorMode(shape_env=shape_env)
        t2 = mode2.from_tensor(t1)
        # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here
        self.assertIsNot(t2, t1)
        self.assertIs(t1.fake_mode, mode1)
        self.assertIs(t2.fake_mode, mode2)
        self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
        self.assertEqual(str(t2.size(0)), str(t1.size(0)))

    def test_jagged_fake_to_fake_preserved(self):
        from torch.nested._internal.nested_tensor import jagged_from_list

        S0, S1, S2 = 3, 4, 5
        D = 4
        a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
        b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
        c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
        offsets = None
        jt, _ = jagged_from_list([a, b, c], offsets)
        shape_env = ShapeEnv()
        mode1 = FakeTensorMode(shape_env=shape_env)
        t1 = mode1.from_tensor(jt)
        mode2 = FakeTensorMode(shape_env=shape_env)
        t2 = mode2.from_tensor(t1)
        # It's not obvious that the invocation above makes it dynamic but it
        # does!
        self.assertTrue(free_symbols(t1.size()))
        self.assertIsNot(t2, t1)
        self.assertIs(t1.offsets().fake_mode, mode1)
        self.assertIs(t2.offsets().fake_mode, mode2)
        self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
        self.assertEqual(str(t2.size(1)), str(t1.size(1)))

    def checkMetaProps(self, t1, t2):
        prims.utils.compare_tensor_meta(t1, t2, check_strides=True)

    @skipIfCrossRef
    def test_deepcopy(self):
        with FakeTensorMode() as mode:
            pass
        mod = torch.nn.BatchNorm2d(10)
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        def check_copy(mod, mod_copied):
            for name, param in itertools.chain(mod.named_parameters(), mod.named_buffers()):
                param_copied = getattr(mod_copied, name)
                self.checkMetaProps(param, param_copied)
                self.assertTrue(isinstance(param_copied, FakeTensor))
                self.assertEqual(isinstance(param, torch.nn.Parameter), isinstance(param_copied, torch.nn.Parameter))
                self.assertEqual(param.requires_grad, param_copied.requires_grad)

        check_copy(mod, mod_copied)

        class ModuleNew(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.rand([10, 2])
                self.b = self.a
                self.c = self.a[0]

        mod = ModuleNew()
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        self.assertIs(mod_copied.a, mod_copied.b)
        self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_new(self):
        with FakeTensorMode():
            a = torch.rand([16, 1])
            self.checkType(a.new(10, 10), "cpu", [10, 10])
            self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
            b = torch.rand([4, 4], device='cuda')
            self.checkType(b.new(device='cuda'), "cuda", [0])
            self.checkType(a.new(torch.rand([1])), "cpu", [1])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_scalar_inputs(self):
        with FakeTensorMode():
            self.checkType(torch.div(3, 2), "cpu", [])
            ten = torch.zeros(2, dtype=torch.int32) * 2.0
            self.assertEqual(ten.dtype, torch.float)
            self.checkType(ten, "cpu", [2])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_allow_meta(self):
        def run_meta():
            with FakeTensorMode():
                x = torch.rand([4], device="meta")
                return x + x

        self.checkType(run_meta(), "meta", [4])

        with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
            self.assertRaises(Exception, run_meta)

    def test_embedding_bag_meta(self):
        def f():
            # This behavior was originally unintentional but we see people
            # relying on it
            embedding = torch.nn.EmbeddingBag(10, 3, mode='sum', device='meta')
            input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
            offsets = torch.tensor([0, 4], dtype=torch.long)
            return embedding(input, offsets)

        real_out = f()
        with FakeTensorMode():
            fake_out = f()

        for r, f in zip(real_out, fake_out):
            self.assertEqual(r.size(), f.size())
            self.assertEqual(r.device, f.device)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_mixed_real_and_fake_inputs(self):
        class _TestPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, input):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        example_inputs = (torch.randn(1, 1, 3, 3),)
        mod = _TestPattern()
        with FakeTensorMode(allow_non_fake_inputs=True):
            out = mod(torch.randn(1, 1, 3, 3))
        self.checkType(out, "cpu", (1, 1, 3, 3))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_copy_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, device="cpu")
            x2 = torch.rand(4, device="cuda")
            copy1 = torch.ops.aten.copy.default(x1, x2)
            copy2 = torch.ops.aten.copy.default(x2, x1)
            out = torch.empty(4, device="cpu")
            torch.ops.aten.copy.out(x1, x2, out=out)
        self.checkType(copy1, "cpu", (4,))
        self.checkType(copy2, "cuda", (4,))
        self.checkType(out, "cpu", (4,))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_index_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            x2 = torch.rand(4, 4, device="cuda")
            i1 = torch.tensor([0, 1], device="cuda")
            i2 = torch.tensor([0, 1], device="cpu")
            r1 = torch.ops.aten.index(x1, i1)
            r2 = torch.ops.aten.index(x2, i2)

            y1 = torch.rand(4, device="cpu")
            y2 = torch.rand(4, device="cuda")
            j1 = torch.tensor([2], device="cuda")
            j2 = torch.tensor([2], device="cpu")
            r3 = torch.ops.aten.index_put.default(x1, j1, y1)
            r4 = torch.ops.aten.index_put.default(x2, j2, y2)
        self.checkType(r1, "cpu", ())
        self.checkType(r2, "cuda", ())
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(r4, "cuda", (4, 4))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_slice_scatter_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            y1 = torch.rand(2, 4, device="cuda")
            x2 = torch.rand(4, 4, device="cuda")
            y2 = torch.rand(2, 4, device="cpu")
            out = torch.empty(4, 4, device="cpu")
            r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
            r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
            r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
        self.checkType(r1, "cpu", (4, 4))
        self.checkType(r2, "cuda", (4, 4))
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(out, "cpu", (4, 4))

    def test__adaptive_avg_pool2d_backward(self):
        with FakeTensorMode():
            grad_out = torch.rand(2, 3, 4, 4)
            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
            self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)

    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    def test_export_numpy(self):
        class MyNumpyModel(torch.nn.Module):
            def forward(self, input):
                input = input.numpy()
                return input + np.random.randn(*input.shape)

        with FakeTensorMode():
            ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
            self.assertTrue(isinstance(ep, torch.export.ExportedProgram))


class FakeTensorConstHandling(TestCase):
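    # FakeTensor can carry a `constant` copy of tensors built from Python
    # literals; the tests below check that constants propagate through ops
    # like item()/add_ and are invalidated once non-constant data flows in.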
    def assertConst(self, *args):
        for arg in args:
            self.assertTrue(arg.constant is not None)

    def assertNotConst(self, *args):
        for arg in args:
            self.assertTrue(arg.constant is None)

    def test_simple(self):
        with FakeTensorMode():
            x = torch.tensor(4.)
            self.assertEqual(x.item(), 4.)

    def test_inplace_add(self):
        with FakeTensorMode():
            x = torch.tensor(4.)
            y = x.add_(1)
            self.assertEqual(x.item(), 5.)
            self.assertEqual(y.item(), 5.)
            self.assertConst(x, y)

    def test_shared_storages(self):
        with FakeTensorMode():
            x = torch.tensor([4.])
            y = x[:]

            self.assertEqual(x.storage()._cdata, y.storage()._cdata)
            self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)

    def test_constant_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1.])
            self.assertConst(x)
            y = torch.rand([1])
            x.add_(y)
            self.assertNotConst(x)

    def test_inplace_view_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1])
            self.assertConst(x)
            x.resize_([2])
            self.assertEqual(x.size(0), 2)
            self.assertNotConst(x)

    def test_fake_tensor_in_intlist_repro(self):

        def fn(tensors):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
            return tensors[0].new_full(batch_shape, 0.0)

        with self.assertRaises(torch._subclasses.fake_tensor.DataDependentOutputException):
            with torch._subclasses.fake_tensor.FakeTensorMode():
                a = torch.randn(3, 800, 1199)
                b = torch.randn(3, 800, 800)
                inputs = [a, b]
                ref = fn(inputs)

    def test_fake_tensor_batch_norm_cpu(self):
        with torch._subclasses.CrossRefFakeMode():
            m = torch.nn.Sequential(
                torch.nn.BatchNorm2d(10),
                torch.nn.ReLU(),
            )
            m.eval()
            out = m(torch.randn([2, 10, 8, 8]))

    def test_shared_storage_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1.])
            y = x[:]
            self.assertConst(x, y)
            y.add_(torch.rand([1]))
            self.assertNotConst(x, y)

    def test_aliased_const_write(self):
        with FakeTensorMode():
            x = torch.tensor([1])
            y = x.expand([4])
            self.assertNotConst(y)
            y[0] = 1
            self.assertNotConst(x)

    def test_constant_propagate_through_functions(self):
        with FakeTensorMode():
            y = torch.div(4, 4, rounding_mode='trunc')
            self.assertConst(y)

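# Recursively check whether `maybe_contained_type` is a subtype of `type` or of
# any type nested inside it (e.g. the Tensor inside Tensor[]).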
def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
    return maybe_contained_type.isSubtypeOf(type) or any(
        contains_type(e, maybe_contained_type) for e in type.containedTypes()
    )


class FakeTensorOpInfoTest(TestCase):
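    # For each sample input, optests.fake_check runs the op under a fake tensor
    # mode and flags divergence from real execution (a brief gloss; see
    # torch.testing._internal.optests for the exact checks performed).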
    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in sample_inputs_itr:
            args = (sample_input.input,) + sample_input.args
            kwargs = sample_input.kwargs
            optests.fake_check(op, args, kwargs)


class FakeTensorConverterTest(TestCase):
    def test_memoized_conversion_to_meta(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))

    def test_memoized_conversion_from_meta(self):
        x = torch.rand(2, 2).to(device="meta")
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        self.assertTrue(converter.from_meta_and_device(mode, x, "cpu") is converter.from_meta_and_device(mode, x, "cpu"))

    def test_separate_tensor_storages_view(self):
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = converter(mode, x)
        y_conv = converter(mode, y)
        self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_separate_tensor_storages_non_view(self):
        x = torch.rand(2, 2, 2)
        y = torch.rand(4, 2)
        y.set_(x.storage())
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = converter(mode, x)
        y_conv = converter(mode, y)
        stor_id = torch._C._storage_id(x_conv)
        self.assertEqual(stor_id, torch._C._storage_id(y_conv))
        del x
        self.assertEqual(len(converter.tensor_memo), 1)
        converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 1)
        del y
        self.assertEqual(len(converter.tensor_memo), 0)
        converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 0)

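    # The converter memoizes conversions keyed by weak references to the real
    # tensors, so memo entries (and cached storages) disappear once the real
    # tensor or its fake dies, as the two tests below verify.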
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_dead_weak_ref(self):
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = converter(mode, x)
        x_conv_storage = torch._C._storage_id(x_conv)
        del x_conv
        self.assertFalse(x in converter.tensor_memo)
        y_conv = converter(mode, y)
        self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_dead_key(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = converter(mode, x)
        self.assertEqual(len(converter.tensor_memo), 1)
        x_conv2 = converter(mode, x)
        assert x_conv2 is x_conv
        del x
        self.assertEqual(len(converter.tensor_memo), 0)

    def test_no_active_mode(self):
        with FakeTensorMode() as mode:
            x = torch.empty(2, 2, device="cpu")
            y = torch.empty(2, 2, device="cpu")

        out = x + y
        self.assertEqual(mode, out.fake_mode)
        self.assertTrue(isinstance(out, FakeTensor))
        self.assertEqual(out.device.type, "cpu")

    def test_multiple_modes(self):
        t = torch.rand([4])
        t2 = torch.rand([4])
        with FakeTensorMode() as m:
            with FakeTensorMode() as m2:
                t_fake = m.from_tensor(t)
                t2_fake = m2.from_tensor(t2)

                with self.assertRaisesRegex(Exception, "Mixing fake modes"):
                    t_fake + t2_fake

    def test_separate_mode_error(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
        with FakeTensorMode():
            y = torch.empty(2, 2, device="cpu")
        self.assertRaises(Exception, lambda: x + y)  # was `lambda: x, y`, which never mixed the modes

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_no_ref_cycle(self):
        x = torch.rand([4])
        mode = FakeTensorMode()
        y = mode.from_tensor(x)
        self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
        mode_weak = weakref.ref(mode)
        y_weak = weakref.ref(y)  # was weakref.ref(mode), which never actually tested y
        del mode
        del y
        assert mode_weak() is None
        assert y_weak() is None


class FakeTensorOperatorInvariants(TestCase):
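    # Resolve a JIT schema such as "aten::add.Tensor" to the corresponding
    # OpOverload, e.g. torch.ops.aten.add.Tensor.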
    @staticmethod
    def get_aten_op(schema):
        namespace, name = schema.name.split("::")
        overload = schema.overload_name if schema.overload_name else "default"
        assert namespace == "aten"
        return getattr(getattr(torch.ops.aten, name), overload)

    @staticmethod
    def get_all_aten_schemas():
        for schema in torch._C._jit_get_all_schemas():
            namespace = schema.name.split("::")[0]
            if namespace != "aten":
                continue
            yield schema

    def test_non_kwarg_only_device(self):
        for schema in self.get_all_aten_schemas():
            ten_type = torch._C.TensorType.get()
            if not any(
                contains_type(arg.type, ten_type)
                for arg in itertools.chain(schema.arguments, schema.returns)
            ):
                continue

            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
            has_non_kwarg_device = any(
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )
            if has_non_kwarg_device:
                self.assertTrue(
                    self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
                )

    def test_tensor_constructors_all_have_kwarg_device(self):
        for schema in self.get_all_aten_schemas():
            op = self.get_aten_op(schema)
            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
                continue

            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
            has_kwarg_device = any(
                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )

            self.assertTrue(
                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
            )

    @unittest.expectedFailure
    def test_sparse_new(self):
        with FakeTensorMode():
            indices = torch.randn(1, 1, dtype=torch.int64)
            values = torch.randn(1)
            extra = (2,)
            sparse = torch.randn(1).to_sparse()
            # This used to segfault, now it does not, but it still raises an
            # error
            sparse2 = sparse.new(indices, values, extra)

    def test_tensor_new(self):
        with FakeTensorMode():
            x = torch.Tensor([1, 2, 3])
        self.assertIsInstance(x, FakeTensor)

    def test_like_ops(self):
        for schema in self.get_all_aten_schemas():
            if "_like" == schema.name[-5:]:
                op = self.get_aten_op(schema)
                self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)

    def test_str_storage(self):
        x = torch.zeros(3)
        with FakeTensorMode() as m:
            y = m.from_tensor(x)
            self.assertExpectedInline(str(x.storage()), '''\
 0.0
 0.0
 0.0
[torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]''')
            self.assertExpectedInline(str(y.storage()), '''\
...
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]''')

        self.assertExpectedInline(str(y.storage()), '''\
...
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]''')

    # at::_embedding_bag has no op info,
    # and returns extra tensors that at::embedding_bag throws away
    def test_embedding_bag_private(self):
        args = [
            torch.ones(6, 1),
            torch.ones(6, dtype=torch.int64),
            torch.arange(2, dtype=torch.int64),
            False,
            2,  # mode = max
        ]

        ref_out = torch.ops.aten._embedding_bag(*args)
        with FakeTensorMode() as m:
            meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
            meta_out = torch.ops.aten._embedding_bag(*meta_args)

        self.assertEqual(len(ref_out), len(meta_out))
        for ref_o, meta_o in zip(ref_out, meta_out):
            self.assertEqual(ref_o.size(), meta_o.size())

    def test_cross_entropy_loss(self):
        inp = torch.randn(3, 5)
        target = torch.randint(5, (3,), dtype=torch.long)
        weight = torch.rand(5)
        fn = torch.nn.functional.cross_entropy
        for w in (weight, None):
            args = (inp, target, w)
            ref = fn(*args)
            with FakeTensorMode() as m:
                meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
                meta_out = torch.nn.functional.cross_entropy(*meta_args, label_smoothing=0.5)

            self.assertEqual(ref.size(), meta_out.size())

    @skipIfRocm
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    def test_flash_attention(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3, scale=0.17677669529663687)

        args_new = [
            [
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
                ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
            ],
            [
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
                ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
            ]
        ]
        for args_list in args_new:
            # each entry is (shape, stride, dtype, device)
            args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_list]
            try:
                with torch._subclasses.CrossRefFakeMode():
                    Repro()(*args)
            except RuntimeError as e:
                # We expect the cross ref to succeed for the first output but to
                # fail for the rng state; see Note [Seed and Offset]
                self.assertTrue("output[0]" not in str(e))
                self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))

    @skipIfRocm
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_conv_c1_backward(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten.convolution_backward.default(
                    arg1,
                    arg2,
                    arg3,
                    [1],
                    [1, 1],
                    [1, 1],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    [True, True, False],
                )

        args_new = [
            ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"),
            ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"),
            ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]

        with torch._subclasses.CrossRefFakeMode():
            Repro()(*args)

    def test_no_dispatch_with_like_function(self):
        class CountingMode(TorchDispatchMode):
            def __init__(self):
                self.count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.count += 1
                return func(*args, **kwargs)

        with FakeTensorMode():
            x = torch.randn(2)
            with CountingMode() as mode:
                with no_dispatch():
                    torch.zeros_like(x)

        self.assertEqual(mode.count, 0)


1223
class FakeTensorPropTest(TestCase):
1224
    def test_fake_tensor_prop_on_nn_module(self):
1225
        class ToyNnModuleWithParameters(torch.nn.Module):
1226
            def __init__(self):
1227
                super().__init__()
1228
                self.layer1 = torch.nn.Linear(4, 3)
1229
                self.layer2 = torch.nn.Linear(3, 2)
1230

1231
            def forward(self, value):
1232
                value = self.layer1(value)
1233
                value = torch.relu(value)
1234
                value = self.layer2(value)
1235
                return value
1236

1237
        model = ToyNnModuleWithParameters()
1238
        value = torch.randn(5, 4)
1239
        # Convert nn.Module to GraphModule so that FakeTensorProp runs.
1240
        graph_model = torch.fx.symbolic_trace(model, (value,))
        # The following block runs FakeTensorProp on graph_model with and
        # without the same FakeTensorMode.
        #
        # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule
        # with parameters and buffers.
        with FakeTensorMode() as fake_tensor_mode:

            def to_fake_tensor(x):
                if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
                    return fake_tensor_mode.from_tensor(x)
                return x

            fake_parameters_and_buffers = {
                k: to_fake_tensor(v)
                for k, v in itertools.chain(
                    graph_model.named_parameters(), graph_model.named_buffers()
                )
            }
            with torch.nn.utils.stateless._reparametrize_module(
                graph_model, fake_parameters_and_buffers
            ):
                # This case uses the **same** fake tensor mode to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The result should be correct.
                result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
                self.assertTrue(isinstance(result, FakeTensor))
                self.assertEqual(result.shape, (5, 2))
                # This case uses **different** fake tensor modes to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The following code should fail.
                failed = False
                try:
                    FakeTensorProp(graph_model).propagate(value)
                except AssertionError:
                    # AssertionError: tensor's device must be `meta`, got cpu instead
                    failed = True
                self.assertTrue(failed)

1281
    def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
1282
        class OptionalArgumentInBetween(torch.nn.Module):
1283
            def __init__(self):
1284
                super().__init__()
1285
                self.layer1 = torch.nn.Linear(4, 3)
1286
                self.layer2 = torch.nn.Linear(3, 2)
1287

1288
            def forward(self, value, another_value=None, another_optional_value=None):
1289
                # Mimic huggingface's `forward` methods which have several optional arguments.
1290
                # For example, GPT accepts forward(self, input_ids, None, attention_mask, ...).
1291
                # To apply FakeTensorProp, its from_real_tensor(...) needs to accept None.
1292
                if another_value is None:
1293
                    another_value = torch.rand_like(value)
1294
                if another_optional_value is None:
1295
                    another_optional_value = torch.rand_like(value)
1296
                value = value + another_value + another_optional_value
1297
                return value * value
1298

1299
        fake_mode = FakeTensorMode(allow_non_fake_inputs=True, allow_fallback_kernels=False)
1300
        with fake_mode:
1301
            model = OptionalArgumentInBetween()
1302
            value = torch.randn(5, 4)
1303
            another_optional_value = torch.randn(5, 4)
1304
            graph_model = torch.fx.symbolic_trace(model, (value, None, another_optional_value))
1305
            FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)
1306
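            # Note: the None placeholders survive propagation because only
            # torch.Tensor inputs are converted to fake tensors; non-tensor
            # arguments should be passed through unchanged.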


    def test_torch_load_with_fake_mode(self):

        class TheModelClass(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.fc1(x)

        with TemporaryFileName() as state_dict_file:
            # Create state_dict to be loaded later
            model = TheModelClass()
            torch.save(model.state_dict(), state_dict_file)

            fake_mode = FakeTensorMode()
            with fake_mode:
                torch.load(state_dict_file)  # scenario 1
                torch.load(state_dict_file, map_location="cpu")  # scenario 2
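                # Both scenarios should yield fake tensors without allocating
                # real storage, which is the point of loading a checkpoint
                # under an active FakeTensorMode. A minimal sketch of the
                # typical use (hypothetical checkpoint path):
                #
                #     with FakeTensorMode():
                #         state_dict = torch.load("huge_checkpoint.pt")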


class FakeTensorSerialization(TestCase):
    def test_serialization(self):
        x = torch.tensor([0], device="cpu")
        with FakeTensorMode():
            y = pickle.loads(pickle.dumps(x))
            self.assertEqual(type(y), FakeTensor)
            self.assertEqual(y.device.type, "meta")

            with unset_fake_temporarily():
                y = pickle.loads(pickle.dumps(x))
                self.assertEqual(x.device, y.device)
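                # Takeaway: under an active fake mode, unpickling produces a
                # FakeTensor reporting a meta device; wrapping the round-trip
                # in unset_fake_temporarily() restores faithful
                # deserialization onto the original device.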

    def test_serialization_with_tracing(self):
        x = torch.tensor([0], device="cpu")
        with tracing(TracingContext(FakeTensorMode())):
            y = pickle.loads(pickle.dumps(x))
            self.assertEqual(x.device, y.device)
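            # Unlike the bare FakeTensorMode case above, a fake mode installed
            # via a TracingContext seems to preserve the original device
            # through the pickle round-trip.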


class FakeTensorDispatchCache(TestCase):
    def test_shape_env_settings(self):
        """
        Validate that any boolean settings in ShapeEnv are present in
        _ShapeEnvSettings. This helps ensure that any new settings that might
        affect FakeTensor dispatch are included in the cache key calculation.
        If this test fails, consider updating _ShapeEnvSettings or changing
        this test to omit checking for the new field.
        """
        init_sig = inspect.signature(ShapeEnv._init)
        args = [
            name for name, param in init_sig.parameters.items()
            if type(param.default) is bool
        ]

        settings = [f.name for f in dataclasses.fields(_ShapeEnvSettings)]
        for arg in args:
            self.assertTrue(arg in settings)

    def _test_cache_key(self, fm, x, y, z):
        """
        Helper for all test_cache_key_* tests below. Assert that the
        cache keys for inputs x and y are the same, but z's is different.
        """
        func = aten.add.Tensor
        key_x = fm._cache_key(func, [x], {})
        key_y = fm._cache_key(func, [y], {})
        key_z = fm._cache_key(func, [z], {})

        self.assertEqual(key_x, key_y)
        self.assertNotEqual(key_x, key_z)
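    # The test_cache_key_* tests below each vary exactly one piece of tensor
    # metadata (dtype, shape, stride, device, memory format, storage offset,
    # requires_grad, conjugate/negative bits, inference-ness) and check that
    # the cache key changes with it.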

    def test_cache_key_dtype(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3, dtype=torch.float16)
            y = torch.randn(4, 3, dtype=torch.float16)
            z = x.to(dtype=torch.float32)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_shape(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = torch.randn(4, 2)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_stride(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 2)
            y = torch.randn(4, 2)
            z = x.as_strided((4, 2), (1, 2))
            self._test_cache_key(fm, x, y, z)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cache_key_device(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = x.to(device="cuda")
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_memory_format(self):
        with FakeTensorMode() as fm:
            x = torch.randn(1, 2, 3, 4)
            y = torch.randn(1, 2, 3, 4)
            z = x.to(memory_format=torch.channels_last)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_storage_offset(self):
        with FakeTensorMode() as fm:
            x = torch.randn(3)[1:]
            y = torch.randn(3)[1:]
            z = torch.randn(2)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_requires_grad(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = torch.randn(4, 3, requires_grad=True)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_is_conj(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3, dtype=torch.complex64)
            y = torch.randn(4, 3, dtype=torch.complex64)
            z = torch.randn(4, 3, dtype=torch.complex64)
            torch._C._set_conj(z, not z.is_conj())
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_is_neg(self):
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3, dtype=torch.complex64)
            y = torch.randn(4, 3, dtype=torch.complex64)
            z = torch.randn(4, 3, dtype=torch.complex64)
            torch._C._set_neg(z, not z.is_neg())
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_is_inference(self):
        with torch.inference_mode(True):
            t = torch.randn(4, 3)
        with FakeTensorMode() as fm:
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)
            z = fm.from_tensor(t)
            self._test_cache_key(fm, x, y, z)

    def test_cache_key_constants(self):
        with FakeTensorMode() as fm:
            # Python hashes 1.0 to the same value as 1. Make sure the
            # cache key calculation differentiates them.
            self._test_cache_key(fm, 1.0, 1.0, 1)
            self._test_cache_key(fm, 0.0, 0.0, 0)

    def assertHitsMisses(self, hits, misses):
        """
        Helper to assert on the number of recorded hits and misses.
        """
        info = FakeTensorMode.cache_info()
        self.assertEqual(info.hits, hits)
        self.assertEqual(info.misses, misses)

    def assertBypasses(self, reason, count):
        """
        Helper to assert on the number of recorded bypasses.
        """
        info = FakeTensorMode.cache_info()
        if count > 0:
            self.assertIn(reason, info.bypasses)
            self.assertEqual(info.bypasses[reason], count)
        else:
            self.assertNotIn(reason, info.bypasses)

    def test_cache_hit(self):
        """
        Test that cache hit/miss counters are updated correctly.
        """
        with FakeTensorMode():
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)
            res1 = x + y
            self.assertHitsMisses(0, 1)
            res2 = x + y
            self.assertHitsMisses(1, 1)

            self.assertEqual(
                extract_tensor_metadata(res1),
                extract_tensor_metadata(res2),
            )
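            # A hit must be indistinguishable from a fresh computation: the
            # served result carries the same metadata, though it is presumably
            # a distinct tensor object.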

    def test_cache_bypass(self):
        """
        Test that cache bypass counters are updated correctly.
        """
        with FakeTensorMode():
            x = torch.randn(1, 2)

            FakeTensorMode.cache_clear()
            self.assertBypasses("inplace view", 0)

            x.unsqueeze_(0)
            self.assertBypasses("inplace view", 1)
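            # Inplace view ops like unsqueeze_ mutate the input's metadata,
            # so caching their outputs would be unsound; they are counted as
            # bypasses instead.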

    def test_cache_default_dtype(self):
        """
        Test that the default dtype is respected when serving cached results.
        """
        with FakeTensorMode():
            x = torch.tensor([1, 2], dtype=torch.int32)
            torch.set_default_dtype(torch.float32)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            y = x + 1.0
            self.assertEqual(y.dtype, torch.float32)
            self.assertHitsMisses(0, 1)

            torch.set_default_dtype(torch.float16)
            y = x + 1.0
            self.assertEqual(y.dtype, torch.float16)
            self.assertHitsMisses(0, 2)

            torch.set_default_dtype(torch.float32)
            y = x + 1.0
            self.assertEqual(y.dtype, torch.float32)
            self.assertHitsMisses(1, 2)
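            # The default dtype affects how the Python scalar 1.0 is promoted,
            # so it presumably has to be part of the effective cache state:
            # the float16 and float32 results above must not be mixed up.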

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cache_default_device(self):
        """
        Test that the default device is respected when serving cached results.
        """
        with FakeTensorMode():
            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            torch.set_default_device("cpu")
            x = torch.tensor([1, 2])
            y = x + 1.0
            self.assertEqual(y.device.type, "cpu")
            self.assertHitsMisses(0, 1)

            torch.set_default_device("cuda")
            x = torch.tensor([1, 2])
            y = x + 1.0
            self.assertEqual(y.device.type, "cuda")
            self.assertHitsMisses(0, 2)

            torch.set_default_device("cpu")
            x = torch.tensor([1, 2])
            y = x + 1.0
            self.assertEqual(y.device.type, "cpu")
            self.assertHitsMisses(1, 2)

    def test_cache_inplace_op(self):
        """
        Test that inplace ops served from the cache correctly reference the
        input parameter.
        """
        with FakeTensorMode():
            x = torch.randn(1, 2)
            y = torch.randn(1, 2)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            z = x.add_(y)
            self.assertHitsMisses(0, 1)
            self.assertEqual(id(x), id(z))

            w = x.add_(y)
            self.assertHitsMisses(1, 1)
            self.assertEqual(id(x), id(w))
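            # For inplace ops, a cached result must alias the mutated input
            # (hence the id() checks) rather than being reconstructed as a
            # fresh tensor.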

    def test_cache_view_op(self):
        """
        Test that view ops are handled correctly when served from the cache.
        """
        with FakeTensorMode():
            x1 = torch.ones(2, requires_grad=True).clone()
            x2 = torch.ones(2, requires_grad=True).clone()
            y2 = x2.view(-1)

            # Test operating on a non-view tensor, then the same operation
            # on a view tensor. Assert that the view property is set correctly.
            z1 = x1.mul_(2)
            self.assertFalse(z1._is_view())

            z2 = y2.mul_(2)
            self.assertTrue(z2._is_view())

            # Now the other way around: first operate on a view tensor, then
            # the same operation on a non-view tensor.
            z2 = y2.mul_(2)
            self.assertTrue(z2._is_view())

            z1 = x1.mul_(2)
            self.assertFalse(z1._is_view())
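            # The _is_view() property comes from the actual input of each
            # call, so serving mul_ from the cache must not leak the view-ness
            # of a previously seen input in either direction.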

    def test_cache_dispatch_key_set(self):
        """
        Test that operations that change the dispatch key set bypass caching.
        """
        with FakeTensorMode():
            FakeTensorMode.cache_clear()
            self.assertBypasses("dispatch_key_set mismatch", 0)

            x = torch._efficientzerotensor(3)
            self.assertTrue(x._is_zerotensor())
            self.assertBypasses("dispatch_key_set mismatch", 1)

            y = torch._efficientzerotensor(3)
            self.assertTrue(y._is_zerotensor())
            self.assertBypasses("dispatch_key_set mismatch", 2)
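            # _efficientzerotensor outputs carry an extra dispatch key
            # (ZeroTensor), which the cache apparently cannot reproduce, so
            # such calls bypass caching every time.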

    def test_inference_mode(self):
        """
        Test that caching handles inference mode correctly.
        """
        with FakeTensorMode():
            x = torch.randn(4, 3)
            y = torch.randn(4, 3)

            FakeTensorMode.cache_clear()
            self.assertHitsMisses(0, 0)

            # Expect a miss when the inference mode is different
            res1 = x + y
            with torch.inference_mode():
                res2 = x + y

            self.assertHitsMisses(0, 2)
            self.assertFalse(res1.is_inference())
            self.assertTrue(res2.is_inference())

            # Second attempts should hit the cache
            res3 = x + y

            self.assertHitsMisses(1, 2)
            self.assertFalse(res3.is_inference())
            self.assertEqual(
                extract_tensor_metadata(res1),
                extract_tensor_metadata(res3),
            )

            with torch.inference_mode():
                res4 = x + y

            self.assertHitsMisses(2, 2)
            self.assertTrue(res4.is_inference())
            self.assertEqual(
                extract_tensor_metadata(res2),
                extract_tensor_metadata(res4),
            )


instantiate_parametrized_tests(FakeTensorTest)

only_for = ("cpu", "cuda")
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()