# pytorch / Fork: 0 / test_cuda.py
# 5364 lines · 206.8 KB
1
# Owner(s): ["module: cuda"]
2

3
import collections
4
import contextlib
5
import ctypes
6
import gc
7
import json
8
import os
9
import pickle
10
import random
11
import subprocess
12
import sys
13
import tempfile
14
import threading
15
import unittest
16
import warnings
17
from copy import deepcopy
18
from itertools import product
19
from random import randint
20

21
import psutil
22

23
import torch
24
import torch.cuda
25
from torch import inf, nan
26
from torch.cuda._memory_viz import (
27
    _profile_to_snapshot,
28
    profile_plot,
29
    segment_plot,
30
    trace_plot,
31
)
32
from torch.testing._internal.autocast_test_lists import AutocastTestLists
33
from torch.testing._internal.common_cuda import (
34
    _create_scaling_case,
35
    _get_torch_cuda_version,
36
    TEST_CUDNN,
37
    TEST_MULTIGPU,
38
)
39
from torch.testing._internal.common_device_type import (
40
    instantiate_device_type_tests,
41
    onlyCUDA,
42
    onlyNativeDeviceTypes,
43
)
44
from torch.testing._internal.common_optimizers import (
45
    _get_optim_inputs_including_global_cliquey_kwargs,
46
    optim_db,
47
    optims,
48
    TensorTracker,
49
)
50
from torch.testing._internal.common_utils import (
51
    EXPANDABLE_SEGMENTS,
52
    freeze_rng_state,
53
    gcIfJetson,
54
    get_cycles_per_ms,
55
    instantiate_parametrized_tests,
56
    IS_ARM64,
57
    IS_FBCODE,
58
    IS_JETSON,
59
    IS_LINUX,
60
    IS_SANDCASTLE,
61
    IS_WINDOWS,
62
    load_tests,
63
    NO_MULTIPROCESSING_SPAWN,
64
    NoTest,
65
    parametrize,
66
    run_tests,
67
    serialTest,
68
    skipCUDAMemoryLeakCheckIf,
69
    skipCUDANonDefaultStreamIf,
70
    skipIfRocm,
71
    slowTest,
72
    subtest,
73
    TemporaryFileName,
74
    TEST_CUDA,
75
    TEST_CUDA_GRAPH,
76
    TEST_NUMPY,
77
    TEST_WITH_ROCM,
78
    TestCase,
79
)
80
from torch.utils.checkpoint import checkpoint_sequential
81
from torch.utils.viz._cycles import observe_tensor_cycles
82

83

84
# `load_tests` from common_utils is used to automatically filter tests for
# sharding on sandcastle; re-binding it here silences flake warnings.
load_tests = load_tests

if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

# torchvision is optional; `skipIfNoTorchVision` gates the tests that need it.
try:
    import torchvision.models  # noqa: F401
    from torchvision.models import resnet18  # noqa: F401
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

TEST_CUDAMALLOCASYNC = TEST_CUDA and (
    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
)
# NOTE: the name reads inverted relative to its value -- TEST_PYNVML is True
# when torch.cuda reports that pynvml is *not* available (presumably consumed
# as a skipIf condition elsewhere in the file).
TEST_PYNVML = not torch.cuda._HAS_PYNVML
if TEST_CUDA:
    # Device-0 capacity thresholds (bytes) for memory-hungry tests.
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
    TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9
    TEST_BF16 = torch.cuda.is_bf16_supported()
else:
    # Without CUDA every capability flag is falsy.
    TEST_LARGE_TENSOR = TEST_CUDA
    TEST_MEDIUM_TENSOR = TEST_CUDA
    TEST_BF16 = False

# Module-level placeholder; not referenced in the visible part of this file.
_cycles_per_ms = None
114

115

116
@torch.testing._internal.common_utils.markDynamoStrictTest
117
class TestCuda(TestCase):
118
    _do_cuda_memory_leak_check = True
119
    _do_cuda_non_default_stream = True
120
    FIFTY_MIL_CYCLES = 50000000
121

122
    def setUp(self):
        """Run the base setup, then build the autocast op lists for cuda:0."""
        super().setUp()
        target_device = torch.device("cuda:0")
        self.autocast_lists = AutocastTestLists(target_device)
125

126
    def tearDown(self):
        """Drop the per-test autocast lists, then delegate to the base teardown."""
        # Release the lists created in setUp before the base tearDown runs.
        delattr(self, "autocast_lists")
        super().tearDown()
129

130
    @property
    def expandable_segments(self):
        """Expose the module-level EXPANDABLE_SEGMENTS flag from common_utils."""
        flag = EXPANDABLE_SEGMENTS
        return flag
133

134
    def test_pinned_memory_with_cudaregister(self):
        """Pin host tensors through the cudaHostRegister-based pinned allocator.

        Enables ``pinned_use_cuda_host_register`` with 8 registration threads,
        then pins a 2**21- and a 2**24-element tensor and checks that each
        reports ``is_pinned()``. The allocator setting is not restored.
        """
        torch.cuda.memory._set_allocator_settings(
            "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
        )
        pageable = torch.ones(20)
        self.assertFalse(pageable.is_pinned())
        try:
            for numel in (1 << 21, 1 << 24):
                pinned = torch.ones(numel).pin_memory()
                self.assertTrue(pinned.is_pinned())
        except RuntimeError:
            # Some GPUs don't support same address space on host and device side
            pass
148

149
    def test_pinned_memory_with_cudaregister_multithread(self):
150
        num_threads = 4
151
        threads = [
152
            threading.Thread(target=self.test_pinned_memory_with_cudaregister)
153
            for t in range(num_threads)
154
        ]
155
        for thread in threads:
156
            thread.start()
157
        for thread in threads:
158
            thread.join()
159

160
    def test_pinned_memory_empty_cache(self):
        """Pin, free and flush the host cache with host-register both on and off."""
        for use_host_register in (True, False):
            setting = f"pinned_use_cuda_host_register:{use_host_register}"
            torch.cuda.memory._set_allocator_settings(setting)
            try:
                pinned = torch.ones(1024 * 1024, pin_memory=True)
                self.assertTrue(pinned.is_pinned())
                del pinned
                # Flush the host (pinned) allocator's cache.
                torch._C._host_emptyCache()
            except RuntimeError:
                # Some GPUs don't support same address space on host and device side
                pass
173

174
    def test_cudart_register(self):
        """Register and unregister a pageable tensor directly via the cudart bindings."""
        host = torch.ones(20)
        self.assertFalse(host.is_pinned())
        cudart = torch.cuda.cudart()
        base_ptr = host.data_ptr()
        nbytes = host.numel() * host.element_size()
        # flags argument is 0; a return code of 0 is asserted as success.
        self.assertEqual(cudart.cudaHostRegister(base_ptr, nbytes, 0), 0)
        self.assertTrue(host.is_pinned())
        self.assertEqual(cudart.cudaHostUnregister(base_ptr), 0)
        self.assertFalse(host.is_pinned())
184

185
    def test_memory_allocation(self):
        """A raw 1-byte caching-allocator block is counted by memory_allocated().

        The block is released in ``finally``, after which memory_allocated()
        must be back to its starting value.
        """
        gc.collect()
        torch.cuda.empty_cache()
        block = None
        baseline = 0
        try:
            baseline = torch.cuda.memory_allocated()
            block = torch.cuda.caching_allocator_alloc(1)
            self.assertGreater(torch.cuda.memory_allocated(), baseline)
        finally:
            if block is not None:
                torch.cuda.caching_allocator_delete(block)
                self.assertEqual(torch.cuda.memory_allocated(), baseline)
199

200
    def test_check_error(self):
        """check_error(0) is silent; check_error(2) raises an out-of-memory CudaError."""
        torch.cuda.check_error(0)  # must not raise

        expected_message = "out of memory|hipErrorOutOfMemory"
        with self.assertRaisesRegex(torch.cuda.CudaError, expected_message):
            torch.cuda.check_error(2)
208

209
    def test_cuda_get_device_name(self):
        """get_device_name(None) and get_device_name() both name the current device."""
        expected = torch.cuda.get_device_name(torch.cuda.current_device())
        implicit_forms = (
            torch.cuda.get_device_name(None),  # explicit None argument
            torch.cuda.get_device_name(),  # no argument at all
        )
        for name in implicit_forms:
            self.assertEqual(expected, name)
219

220
    def test_cuda_get_device_capability(self):
        """get_device_capability(None) and get_device_capability() match the current device."""
        expected = torch.cuda.get_device_capability(torch.cuda.current_device())
        implicit_forms = (
            torch.cuda.get_device_capability(None),  # explicit None argument
            torch.cuda.get_device_capability(),  # no argument at all
        )
        for capability in implicit_forms:
            self.assertEqual(expected, capability)
230

231
    def test_out_of_memory(self):
        """Absurd allocations raise descriptive errors and leave the device usable."""
        canary = torch.zeros(1024, device="cuda")
        gib = 1024 * 1024 * 1024

        if TEST_CUDAMALLOCASYNC:
            oom_regex = "would exceed allowed memory"
        else:
            oom_regex = "Tried to allocate 800000000.00 GiB"
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(gib * 800000000, dtype=torch.int8, device="cuda")

        # The >1 EB message is expected regardless of the allocator backend.
        with self.assertRaisesRegex(
            RuntimeError, "Tried to allocate more than 1EB memory"
        ):
            torch.empty(gib * 8000000000, dtype=torch.int8, device="cuda")

        # ensure out of memory error doesn't disturb subsequent kernel
        canary.fill_(1)
        self.assertTrue((canary == 1).all())
252

253
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)"
    )
    @serialTest()
    def test_out_of_memory_retry(self):
        """After an OOM, freeing the blocking tensor lets the same request succeed."""
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        if TEST_CUDAMALLOCASYNC:
            oom_regex = "would exceed allowed memory"
        else:
            oom_regex = "Tried to allocate"
        half = int(total_memory * 0.5)

        holder = torch.empty(half, dtype=torch.int8, device="cuda")
        # A second half-of-device request cannot fit next to `holder`.
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(half, dtype=torch.int8, device="cuda")
        del holder
        retried = torch.empty(half, dtype=torch.int8, device="cuda")
        del retried
        # We used a lot of memory here, clean up so we don't affect other tests too much
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
275

276
    @serialTest()
    def test_set_per_process_memory_fraction(self):
        """Validate arguments and enforce a 50% per-process memory cap on device 0.

        The cap is lifted again in ``finally`` by resetting the fraction to 1.0,
        the largest value the API accepts. Otherwise every later test in the
        same process would be limited to half of device 0.
        """
        # test invalid fraction value.
        with self.assertRaisesRegex(TypeError, "Invalid type"):
            torch.cuda.set_per_process_memory_fraction(1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(-0.1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(2.0)

        tensor = torch.zeros(1024, device="cuda")
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        torch.cuda.set_per_process_memory_fraction(0.5, 0)
        try:
            # test 0.499 allocation is ok.
            application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved()
            tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda")
            del tmp_tensor
            torch.cuda.empty_cache()

            application = int(total_memory * 0.5)
            # it will get OOM when try to allocate more than half memory.
            oom_regex = (
                "would exceed allowed memory"
                if TEST_CUDAMALLOCASYNC
                else "out of memory"
            )
            with self.assertRaisesRegex(RuntimeError, oom_regex):
                torch.empty(application, dtype=torch.int8, device="cuda")

            # ensure out of memory error doesn't disturb subsequent kernel
            tensor.fill_(1)
            self.assertTrue((tensor == 1).all())
        finally:
            # Undo the 0.5 cap so it does not leak into subsequent tests.
            torch.cuda.set_per_process_memory_fraction(1.0, 0)
308

309
    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "uuid attribute not yet available")
    def test_uuid(self):
        """Device 0 exposes a 16-byte UUID whose string form is 36 characters."""
        device_uuid = torch.cuda.get_device_properties(0).uuid
        # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
        text_length = len(str(device_uuid))
        self.assertEqual(text_length, 36)
        self.assertEqual(len(device_uuid.bytes), 16)
314

315
    def test_copy_non_blocking(self):
        """Non-blocking copies between pinned host and device memory, in both directions."""

        def check(dst, src):
            # Record an event after the async copy and wait on it before comparing.
            done = torch.cuda.Event()
            dst.copy_(src, non_blocking=True)
            done.record()
            done.synchronize()
            self.assertEqual(dst, src)

        numel = 10000000  # 10MB copies of uint8

        # pinned host -> device
        check(
            torch.ones(numel, dtype=torch.uint8).cuda(),
            torch.zeros(numel, dtype=torch.uint8).pin_memory(),
        )
        # device -> pinned host
        check(
            torch.zeros(numel, dtype=torch.uint8).pin_memory(),
            torch.ones(numel, dtype=torch.uint8).cuda(),
        )

        # Test the case where the pinned data_ptr is not equal to the storage data_ptr.
        host_base = torch.zeros(numel, dtype=torch.uint8).pin_memory()
        host_view = host_base[1:]
        self.assertTrue(host_view.is_pinned())
        self.assertTrue(host_base.is_pinned())
        self.assertNotEqual(host_base.data_ptr(), host_view.data_ptr())
        self.assertEqual(
            host_base.storage().data_ptr(), host_view.storage().data_ptr()
        )
        check(host_view, torch.ones(numel - 1, dtype=torch.uint8).cuda())
341

342
    def test_copy_non_blocking_type_conversion(self):
        """Chained non-blocking copies (float GPU -> pinned CPU -> long GPU) stay ordered."""
        source = torch.ones(1, device="cuda")
        staging = torch.zeros(1, device="cpu", pin_memory=True)
        target = torch.empty(1, device="cuda", dtype=torch.long)
        # Queue ~100 ms of GPU spin ahead of the copies.
        spin_cycles = int(100 * get_cycles_per_ms())
        torch.cuda._sleep(spin_cycles)
        staging.copy_(source, non_blocking=True)
        target.copy_(staging, non_blocking=True)
        self.assertEqual(source, target, exact_dtype=False)
350

351
    @serialTest()
    def test_to_non_blocking(self):
        """`.to(non_blocking=True)` returns before the copy completes; blocking does not.

        Non-blocking device-to-host results must land in pinned memory.
        """
        stream = torch.cuda.current_stream()

        def run_case(src, non_blocking, dst):
            torch.cuda.synchronize()
            # Pushes an 0.1 second spin to stream so if the copy is non blocking,
            # stream will almost surely be active when we query().
            torch.cuda._sleep(int(100 * get_cycles_per_ms()))
            out = src.to(device=dst, non_blocking=non_blocking)
            self.assertEqual(stream.query(), not non_blocking)
            stream.synchronize()
            self.assertEqual(src, out)
            expect_pinned = non_blocking and dst == "cpu"
            self.assertTrue(out.is_pinned() == expect_pinned)

        for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)):
            # Creates source on the opposite device from destination.
            copying_to_cuda = dst == "cuda"
            src = torch.randn(
                1000000,
                device="cpu" if copying_to_cuda else "cuda",
                pin_memory=copying_to_cuda,
            )
            run_case(src, try_non_blocking, dst)
374

375
    def test_to_cpu_blocking_by_default(self):
        """A plain `.to("cpu")` waits for queued GPU work and returns pageable memory."""
        device_src = torch.randn(1000000, device="cuda")
        torch.cuda.synchronize()
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        host_copy = device_src.to(device="cpu")
        # The blocking copy must already have drained the queued spin.
        self.assertEqual(torch.cuda.current_stream().query(), True)
        self.assertEqual(device_src, host_copy)
        self.assertFalse(host_copy.is_pinned())
383

384
    def test_serialization_array_with_storage(self):
        """Round-trip a list mixing CUDA tensors and a storage, keeping types and aliasing."""
        floats = torch.randn(5, 5).cuda()
        ints = torch.IntTensor(2, 5).fill_(0).cuda()
        payload = [floats, ints, floats, ints.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(payload, f)
            f.seek(0)
            loaded = torch.load(f)
        self.assertEqual(loaded, payload, atol=0, rtol=0)

        # Entries 0 and 2 were the same tensor, so they must still alias.
        loaded[0].fill_(5)
        self.assertEqual(loaded[0], loaded[2], atol=0, rtol=0)
        self.assertIsInstance(loaded[0], torch.cuda.FloatTensor)
        self.assertIsInstance(loaded[1], torch.cuda.IntTensor)
        self.assertIsInstance(loaded[2], torch.cuda.FloatTensor)
        self.assertIsInstance(loaded[3], torch.storage.TypedStorage)
        self.assertIsInstance(loaded[3]._untyped_storage, torch.UntypedStorage)

        # Entry 3 is the storage of entry 1, so writes through 1 show up in 3.
        loaded[1].fill_(10)
        self.assertEqual(loaded[3], torch.cuda.IntStorage(10).fill_(10))
402

403
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
    )
    @unittest.skipIf(
        _get_torch_cuda_version() >= (12, 2),
        "skipped as explicit workspace allocation is removed",
    )
    def test_cublas_workspace_explicit_allocation(self):
        """cuBLAS workspace size follows CUBLAS_WORKSPACE_CONFIG.

        The test mutates the process-wide ``CUBLAS_WORKSPACE_CONFIG`` variable.
        Its previous value is restored (or the key removed) in ``finally`` so
        the override cannot leak into other tests.
        """
        a = torch.randn(7, 7, device="cuda", requires_grad=False)
        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
        # different size (32 MiB) expected on Hopper GPU
        if torch.cuda.get_device_capability() == (9, 0):
            default_workspace_size = 4096 * 8 * 1024
        tolerance = 524288  # 512 KiB

        def check_workspace_size(inp):
            torch._C._cuda_clearCublasWorkspaces()
            start = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            with torch.no_grad():
                torch.matmul(inp, inp)
            finish = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            return finish - start

        env_key = "CUBLAS_WORKSPACE_CONFIG"
        saved_config = os.environ.get(env_key)
        try:
            # check default
            os.environ[env_key] = ""
            self.assertTrue(
                abs(check_workspace_size(a) - default_workspace_size) < tolerance
            )

            # check default with bad user config
            os.environ[env_key] = "-1"
            self.assertTrue(
                abs(check_workspace_size(a) - default_workspace_size) < tolerance
            )

            # check valid config: 128*8 + 64*16 + 32*32 KiB = 3072 KiB
            os.environ[env_key] = ":128:8:64:16:32:32"
            self.assertTrue(abs(check_workspace_size(a) - (3072 * 1024)) < tolerance)
        finally:
            if saved_config is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = saved_config
            torch._C._cuda_clearCublasWorkspaces()
438

439
    def test_cublas_allow_tf32_get_set(self):
440
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
441
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
442
        )
443
        if skip_tf32_cublas:
444
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
445
            return
446

447
        orig = torch.backends.cuda.matmul.allow_tf32
448
        self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
449
        torch.backends.cuda.matmul.allow_tf32 = not orig
450
        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
451
        torch.backends.cuda.matmul.allow_tf32 = orig
452

453
    def test_float32_matmul_precision_get_set(self):
454
        orig = torch.get_float32_matmul_precision()
455
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
456
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
457
        )
458
        # this is really just checking that the environment variable is respected during testing
459
        # and not overwritten by another function that doesn't revert it to the intitial value
460
        if not skip_tf32_cublas:
461
            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
462
            self.assertEqual(torch.get_float32_matmul_precision(), "highest")
463
        else:
464
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
465
        for p in ("medium", "high"):
466
            torch.set_float32_matmul_precision(p)
467
            self.assertEqual(torch.get_float32_matmul_precision(), p)
468
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
469
        torch.set_float32_matmul_precision("highest")
470
        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
471
        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
472
        torch.set_float32_matmul_precision(orig)
473

474
    def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
475
        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
476
        self.assertEqual(
477
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
478
        )
479
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
480
        self.assertEqual(
481
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
482
        )
483
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
484

485
    def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
486
        orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
487
        self.assertEqual(
488
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
489
        )
490
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
491
        self.assertEqual(
492
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
493
        )
494
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
495

496
    def test_cudnn_allow_tf32_get_set(self):
497
        with torch.backends.cudnn.flags(
498
            enabled=None, benchmark=None, deterministic=None, allow_tf32=False
499
        ):
500
            self.assertFalse(torch.backends.cudnn.allow_tf32)
501
        with torch.backends.cudnn.flags(
502
            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
503
        ):
504
            self.assertTrue(torch.backends.cudnn.allow_tf32)
505

506
    def test_type_conversions(self):
        """Legacy tensor/storage type classes survive CPU<->CUDA and dtype conversions."""
        src = torch.randn(5, 5)
        tensor_expectations = [
            (src.float(), torch.FloatTensor),
            (src.cuda().double(), torch.cuda.DoubleTensor),
            (src.cuda().float(), torch.cuda.FloatTensor),
            (src.cuda().float().cpu(), torch.FloatTensor),
            (src.cuda().float().cpu().int(), torch.IntTensor),
        ]
        for converted, expected_type in tensor_expectations:
            self.assertIsInstance(converted, expected_type)

        storage = src.storage()
        storage_expectations = [
            (storage.float(), torch.FloatStorage),
            (storage.cuda().double(), torch.cuda.DoubleStorage),
            (storage.cuda().float(), torch.cuda.FloatStorage),
            (storage.cuda().float().cpu(), torch.FloatStorage),
            (storage.cuda().float().cpu().int(), torch.IntStorage),
        ]
        for converted, expected_type in storage_expectations:
            self.assertIsInstance(converted, expected_type)
520

521
    @unittest.skip("was disabled due to not enough memory, but actually it always fail")
    def test_arithmetic_large_tensor(self):
        """In-place arithmetic and sum() on a 2**30-element CUDA tensor."""
        numel = 2**30
        big = torch.empty(numel, device="cuda")

        big.fill_(1)
        self.assertEqual(big.sum(), numel)

        big += 1  # every element is now 2
        self.assertEqual(big.sum(), 2 * numel)

        big.fill_(1)
        big -= 0.5  # every element is now 0.5
        self.assertEqual(big.sum(), numel // 2)

        big.fill_(1)
        big *= 2
        self.assertEqual(big.sum(), 2 * numel)

        big.fill_(1)
        big /= 2
        self.assertEqual(big.sum(), numel // 2)
542

543
    def test_gather_bool(self):
        """torch.gather along dim 1 works on bool CUDA tensors."""
        flags = torch.tensor([[False, True], [True, True]], device="cuda")
        index = torch.tensor([[0, 0], [1, 0]], device="cuda")
        gathered = torch.gather(flags, 1, index)
        expected = torch.tensor([[False, False], [True, True]], device="cuda")
        self.assertEqual(gathered, expected)
549

550
    def test_torch_manual_seed_seeds_cuda_devices(self):
        """torch.manual_seed also seeds the CUDA generator reproducibly."""
        seed = 2
        with freeze_rng_state():
            first = torch.zeros(4, 4).float().cuda()
            torch.manual_seed(seed)
            self.assertEqual(torch.cuda.initial_seed(), seed)
            first.uniform_()
            # Re-seeding must replay the same CUDA random sequence.
            torch.manual_seed(seed)
            second = first.clone().uniform_()
            self.assertEqual(first, second)
            self.assertEqual(torch.cuda.initial_seed(), seed)
560

561
    def test_manual_seed(self):
        """torch.cuda.manual_seed makes uniform_ and bernoulli draws reproducible."""
        seed = 2

        def draw(t):
            # uniform_ in place, then one bernoulli sample shaped like `t`.
            t.uniform_()
            return t, torch.bernoulli(torch.full_like(t, 0.5))

        with freeze_rng_state():
            first = torch.zeros(4, 4).float().cuda()
            torch.cuda.manual_seed(seed)
            self.assertEqual(torch.cuda.initial_seed(), seed)
            first, coin_a = draw(first)
            torch.cuda.manual_seed(seed)
            second, coin_b = draw(first.clone())
            self.assertEqual(first, second)
            self.assertEqual(coin_a, coin_b)
            self.assertEqual(torch.cuda.initial_seed(), seed)
574

575
    def test_specify_improper_device_name(self):
        """torch.load rejects a malformed map_location string such as "cuda0".

        Uses a unique temporary path from ``TemporaryFileName`` instead of a
        fixed ``tempfile.pt`` in the working directory. A fixed name could
        collide between parallel test runs and leave stray files behind. Only
        the ``torch.load`` call sits inside ``assertRaisesRegex``, so an
        unexpected failure while saving is not mistaken for the expected error.
        """
        with TemporaryFileName() as fname:
            torch.save(
                [torch.nn.Parameter(torch.randn(10, 10))],
                fname,
                _use_new_zipfile_serialization=True,
            )
            with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
                torch.load(fname, "cuda0")
590

591
    def test_get_device_index(self):
        """_get_device_index rejects malformed strings and non-CUDA devices."""
        from torch.cuda._utils import _get_device_index

        with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
            _get_device_index("cuda0", optional=True)

        cpu_device = torch.device("cpu")
        with self.assertRaisesRegex(ValueError, "Expected a cuda device"):
            _get_device_index(cpu_device, optional=True)
600

601
    def test_serialization_array_with_empty(self):
        """Round-trip a list containing an empty CUDA tensor, preserving type and device."""
        tensors = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(tensors, f)
            f.seek(0)
            restored = torch.load(f)
        for src, dst in zip(tensors, restored):
            self.assertEqual(dst, src)
            self.assertIs(type(dst), type(src))
            self.assertEqual(dst.get_device(), src.get_device())
611

612
    @skipCUDANonDefaultStreamIf(True)
    def test_streams(self):
        """Default vs. user streams: identity, handle values, context switching, sync."""
        default_stream = torch.cuda.current_stream()
        side_stream = torch.cuda.Stream()
        self.assertEqual(torch.cuda.current_stream(), default_stream)
        self.assertNotEqual(default_stream, side_stream)
        # The default stream's raw handle is 0; a user-created stream's is not.
        self.assertEqual(default_stream.cuda_stream, 0)
        self.assertNotEqual(side_stream.cuda_stream, 0)
        with torch.cuda.stream(side_stream):
            self.assertEqual(torch.cuda.current_stream(), side_stream)
        # Nothing was queued on the side stream, so it reports idle.
        self.assertTrue(side_stream.query())
        pinned = torch.ByteTensor(5).pin_memory()
        _incremented = pinned.cuda(non_blocking=True) + 1
        default_stream.synchronize()
        self.assertTrue(default_stream.query())
627

628
    def test_stream_event_repr(self):
        """Stream and Event reprs name their classes, before and after recording."""
        stream = torch.cuda.current_stream()
        self.assertIn("torch.cuda.Stream", repr(stream))
        event = torch.cuda.Event()
        self.assertIn("torch.cuda.Event", repr(event))
        stream.record_event(event)
        # repr must still work once the event has been recorded.
        self.assertIn("torch.cuda.Event", repr(event))
635

636
    def test_events(self):
        """Timing events: query state around a GPU spin and a positive elapsed time."""
        stream = torch.cuda.current_stream()
        end = torch.cuda.Event(enable_timing=True)
        self.assertTrue(end.query())  # an unrecorded event reports complete
        start = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        torch.cuda._sleep(int(50 * get_cycles_per_ms()))  # ~50 ms of GPU spin
        stream.record_event(end)
        self.assertFalse(end.query())
        end.synchronize()
        self.assertTrue(end.query())
        self.assertGreater(start.elapsed_time(end), 0)
648

649
    def test_generic_stream_event(self):
        """Device-generic torch.Stream/torch.Event interoperate with torch.cuda APIs."""
        generic = torch.Stream("cuda")
        self.assertEqual(generic.device_index, torch.cuda.current_device())
        # Wrap the same underlying stream in a torch.cuda.Stream.
        wrapped = torch.cuda.Stream(
            stream_id=generic.stream_id,
            device_index=generic.device_index,
            device_type=generic.device_type,
        )
        self.assertEqual(generic.stream_id, wrapped.stream_id)
        self.assertNotEqual(generic.stream_id, torch.cuda.current_stream().stream_id)

        ev_first = torch.Event("cuda", enable_timing=True)
        ev_second = torch.Event("cuda", enable_timing=True)
        # event_id is 0 before the event has been recorded.
        self.assertEqual(ev_first.event_id, 0)
        lhs = torch.randn(1000)
        rhs = torch.randn(1000)
        with torch.cuda.stream(wrapped):
            lhs_cuda = lhs.to("cuda", non_blocking=True)
            rhs_cuda = rhs.to("cuda", non_blocking=True)
            self.assertEqual(generic.stream_id, torch.cuda.current_stream().stream_id)
        ev_first.record(generic)
        ev_first.synchronize()
        self.assertTrue(ev_first.query())
        total = lhs_cuda + rhs_cuda
        ev_second.record()
        ev_second.synchronize()
        self.assertTrue(ev_second.query())
        self.assertNotEqual(ev_first.event_id, ev_second.event_id)
        self.assertEqual(total.cpu(), lhs + rhs)
        self.assertGreater(ev_first.elapsed_time(ev_second), 0)
679

680
    def test_record_stream(self):
        """record_stream delays reuse of a side-stream block until its consumer is done.

        ``tmp`` is allocated on ``stream`` and consumed on the current stream
        behind a ~50 ms delay. Because ``record_stream`` was called, a new
        allocation on ``stream`` must not reuse ``tmp``'s memory while that
        copy is still pending.
        """
        cycles_per_ms = get_cycles_per_ms()

        t = torch.FloatTensor([1, 2, 3, 4]).pin_memory()
        result = torch.cuda.FloatTensor(t.size())
        stream = torch.cuda.Stream()
        ptr = [None]  # data_ptr of the side-stream allocation, set by perform_copy

        # Performs the CPU->GPU copy in a background stream
        def perform_copy():
            with torch.cuda.stream(stream):
                tmp = t.cuda(non_blocking=True)
                ptr[0] = tmp.data_ptr()
            torch.cuda.current_stream().wait_stream(stream)
            tmp.record_stream(torch.cuda.current_stream())
            torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy
            result.copy_(tmp)

        perform_copy()
        with torch.cuda.stream(stream):
            tmp2 = torch.cuda.FloatTensor(t.size())
            tmp2.zero_()
            self.assertNotEqual(
                tmp2.data_ptr(), ptr[0], msg="allocation re-used too soon"
            )

        self.assertEqual(result.tolist(), [1, 2, 3, 4])

        if not TEST_CUDAMALLOCASYNC:
            # In the native allocator, we expect "tmp"'s side-stream-tagged block will be reused
            # in that side stream after result.copy_(tmp) in the main stream finishes.
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(stream):
                tmp3 = torch.cuda.FloatTensor(t.size())
                self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used")
715

716
    def test_record_stream_on_shifted_view(self):
        """record_stream on a view with a non-zero storage offset protects its block.

        See issue #27366.
        """
        # This test detects unexpected block reallocation. For reliable test,
        # the stream to allocate tensors is isolated. The allocator will not
        # reuse free blocks which were allocated from another stream.
        stream_alloc = torch.cuda.Stream()
        with torch.cuda.stream(stream_alloc):
            # NOTE(review): FloatTensor([10, 10]) builds a 2-element tensor from
            # data (not a 10x10 one), so base[5:] below is an empty view. The
            # only requirement checked is a storage offset > 0 -- confirm intent.
            base = torch.cuda.FloatTensor([10, 10])

        # Record another stream on a shifted view tensor.
        view = base[5:]
        # assertGreater rather than a bare assert, so the precondition is
        # still checked when Python runs with -O.
        self.assertGreater(view.storage_offset(), 0)

        stream_record = torch.cuda.Stream()
        with torch.cuda.stream(stream_record):
            torch.cuda._sleep(int(50 * get_cycles_per_ms()))

        view.record_stream(stream_record)

        # Delete those tensors to make the block free soon.
        data_ptr = base.data_ptr()
        del base, view

        # A new tensor should not be allocated to the block above.
        stream_alloc.synchronize()

        with torch.cuda.stream(stream_alloc):
            try_realloc = torch.cuda.FloatTensor([10, 10])

        self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
747

748
    def test_noncontiguous_pinned_memory(self):
        """Pinning a non-contiguous (transposed) tensor preserves its values."""
        # See issue #3266
        transposed = torch.arange(0, 10).view((2, 5)).t()
        self.assertEqual(transposed, transposed.pin_memory())
752

753
    def test_caching_pinned_memory(self):
        """Pinned blocks are reused once free, but not while an async copy still uses them."""
        cycles_per_ms = get_cycles_per_ms()

        # check that allocations are re-used after deletion
        pinned = torch.FloatTensor([1]).pin_memory()
        first_ptr = pinned.data_ptr()
        del pinned
        pinned = torch.FloatTensor([1]).pin_memory()
        self.assertEqual(pinned.data_ptr(), first_ptr, msg="allocation not reused")

        # check that the allocation is not re-used if it's in-use by a copy
        device_buf = torch.cuda.FloatTensor([0])
        torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by 1s
        device_buf.copy_(pinned, non_blocking=True)
        del pinned
        pinned = torch.FloatTensor([1]).pin_memory()
        self.assertNotEqual(
            pinned.data_ptr(), first_ptr, msg="allocation re-used too soon"
        )
        self.assertEqual(list(device_buf), [1])
771

772
    def test_caching_allocator_record_stream_oom(self):
773
        """allocations delayed by a record_stream call should still be freed on
774
        an out-of-memory in cuda_malloc_retry. see issue #19219"""
775
        stream = torch.cuda.Stream()
776

777
        with torch.cuda.stream(stream):
778
            y = torch.zeros(40 * 1024 * 1024, device="cuda")
779

780
        for _ in range(100):
781
            x = torch.empty(40 * 1024 * 1024, device="cuda")
782
            with torch.cuda.stream(stream):
783
                y += x
784
            # delays re-use of `x` until after all operations in `stream`
785
            x.record_stream(stream)
786
            del x
787

788
        # we've made a mess by allocating up to the device capacity. free any
789
        # cached blocks in case it affects future tests.
790
        torch.cuda.empty_cache()
791

792
    # Tests for historic illegal memory access, see #17040.
793
    def test_reduction_gpu_memory_accessing(self):
794
        x = torch.ones(512, 8, dtype=torch.float32, device="cuda")
795
        torch.sum(x, 0)
796

797
    def test_sum_fp16(self):
798
        x = torch.zeros(10, device="cuda", dtype=torch.float16)
799
        self.assertEqual(x.sum(), 0)
800

801
        x = torch.ones(65504, device="cuda", dtype=torch.float16)
802
        self.assertEqual(x.sum(), 65504)
803
        self.assertEqual(x.sum(dtype=torch.float32), 65504)
804

805
        x = torch.ones(65536, device="cuda", dtype=torch.float16)
806
        self.assertEqual(x.sum(dtype=torch.float32), 65536)
807

808
        a = torch.zeros(1203611).bernoulli_(0.0005)
809
        x = a.to(device="cuda", dtype=torch.float16)
810
        self.assertEqual(x.sum().item(), a.sum().item())
811

812
        a = torch.zeros(100, 121, 80).bernoulli_(0.0005)
813
        x = a.to(device="cuda", dtype=torch.float16)
814
        self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2)))
815

816
    def test_mean_fp16(self):
817
        x = torch.ones(65536, device="cuda", dtype=torch.float16)
818
        self.assertEqual(x.mean(), 1)
819

820
        x = torch.ones(65536, device="cuda", dtype=torch.float16)
821
        self.assertEqual(x.mean(dtype=torch.float32), 1)
822

823
    def test_prod_large(self):
824
        # tests global reduction (should_global_reduce = true) in case of non-zero identity element
825
        x = torch.ones(240000, device="cuda", dtype=torch.float32)
826
        self.assertEqual(x.prod(), 1)
827

828
        # test for complex types. Note 240k is divisible by 4
829
        for dtype in [torch.cfloat, torch.cdouble]:
830
            x = torch.ones(240000, device="cuda", dtype=dtype) * (0 + 1j)
831
            self.assertEqual(x.prod(), 1)
832

833
    def test_multinomial_ext(self):
834
        # Test two corner cases from older PyTorch (Issue #4858)
835
        freqs = torch.cuda.FloatTensor(
836
            [
837
                0.0,
838
                0.0,
839
                0.0,
840
                0.0,
841
                0.0,
842
                0.0,
843
                0.0,
844
                0.0,
845
                0.0,
846
                0.03178183361887932,
847
                0.027680952101945877,
848
                0.033176131546497345,
849
                0.046052902936935425,
850
                0.07742464542388916,
851
                0.11543981730937958,
852
                0.14148041605949402,
853
                0.15784293413162231,
854
                0.13180233538150787,
855
                0.08271478116512299,
856
                0.049702685326337814,
857
                0.027557924389839172,
858
                0.018125897273421288,
859
                0.011851548217236996,
860
                0.010252203792333603,
861
                0.007422595750540495,
862
                0.005372154992073774,
863
                0.0045109698548913,
864
                0.0036087757907807827,
865
                0.0035267581697553396,
866
                0.0018864056328311563,
867
                0.0024605290964245796,
868
                0.0022964938543736935,
869
                0.0018453967059031129,
870
                0.0010662291897460818,
871
                0.0009842115687206388,
872
                0.00045109697384759784,
873
                0.0007791675161570311,
874
                0.00020504408166743815,
875
                0.00020504408166743815,
876
                0.00020504408166743815,
877
                0.00012302644609007984,
878
                0.0,
879
                0.00012302644609007984,
880
                4.100881778867915e-05,
881
                0.0,
882
                0.0,
883
                0.0,
884
                0.0,
885
                0.0,
886
                0.0,
887
            ]
888
        )
889

890
        torch.cuda.manual_seed(11042)
891
        sample = torch.multinomial(freqs, 1000, True)
892
        self.assertNotEqual(freqs[sample].min(), 0)
893

894
        p = torch.zeros(3421, 2, device="cuda", dtype=torch.float)
895
        p[:, 1] = 1
896
        torch.cuda.manual_seed(5214)
897
        r = torch.multinomial(p, 1)
898
        self.assertNotEqual(r.min().item(), 0)
899

900
        # test corner case from Issue #13867
901
        torch.cuda.manual_seed(33)
902
        probs = torch.randn(1000000, device="cuda").clamp(min=0) * 3e-5
903
        samples = probs.multinomial(1000000, replacement=True)
904
        self.assertGreater(probs[samples].min().item(), 0)
905

906
    def _spawn_test_multinomial_invalid_probs_cuda(self, probs):
907
        import subprocess
908

909
        try:
910
            p = subprocess.Popen(
911
                [
912
                    sys.executable,
913
                    "-c",
914
                    f"""\
915
import sys
916
import torch
917
from torch import inf, nan
918
try:
919
    with torch.random.fork_rng(devices=[0]):
920
        torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True)
921
        torch.cuda.synchronize()
922
    sys.exit(-1) # Should not be reached
923
except RuntimeError as e:
924
    sys.exit(-2)
925
""",
926
                ],
927
                stdout=subprocess.PIPE,
928
                stderr=subprocess.PIPE,
929
                universal_newlines=True,
930
            )
931
            out, err = p.communicate(timeout=10)
932
            p.wait(timeout=10)
933
        except subprocess.TimeoutExpired as e:
934
            p.kill()
935
            out, err = p.communicate()
936
        expected_messages = [
937
            "device-side assert triggered",  # CUDA
938
            "Assertion",  # CUDA
939
            "HSA_STATUS_ERROR_EXCEPTION",  # ROCm
940
            "Device-side assertion",  # ROCm
941
        ]
942
        self.assertTrue(any(msg in out or msg in err for msg in expected_messages))
943

944
    @slowTest
945
    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
946
    @unittest.skipIf(
947
        NO_MULTIPROCESSING_SPAWN,
948
        "Disabled for environments that \
949
                     don't support multiprocessing with spawn start method",
950
    )
951
    def test_multinomial_invalid_probs_cuda(self):
952
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -1.0, 1.0])
953
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, inf, 1.0])
954
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -inf, 1.0])
955
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, 1.0, nan])
956

957
    @staticmethod
958
    def _mute_init():
959
        os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())
960

961
    def _spawn_method(self, method, arg):
962
        ctx = torch.multiprocessing.get_context("spawn")
963
        with ctx.Pool(1, initializer=self._mute_init) as pool:
964
            errors = pool.map(method, [arg])
965
            for e in errors:
966
                if "device-side assert triggered" not in str(e):
967
                    self.fail(e)
968

969
    @staticmethod
970
    def _test_index_bounds_cuda(idx):
971
        x = torch.arange(10, device="cuda")
972
        try:
973
            y = x[torch.tensor([idx])]
974
            return f"x[torch.tensor([{idx})]={y}"
975
        except RuntimeError as err:
976
            return err
977

978
    @slowTest
979
    @unittest.skipIf(
980
        NO_MULTIPROCESSING_SPAWN,
981
        "Disabled for environments that \
982
                     don't support multiprocessing with spawn start method",
983
    )
984
    @skipIfRocm
985
    def test_index_out_of_bounds_exception_cuda(self):
986
        test_method = TestCuda._test_index_bounds_cuda
987
        # Test in-bound access works fine
988
        self.assertEqual(
989
            test_method(1), "x[torch.tensor([1)]=tensor([1], device='cuda:0')"
990
        )
991
        # Test that indexing out of bounds causes assert
992
        self._spawn_method(test_method, 11)
993

994
    @slowTest
995
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
996
    @serialTest()
997
    def test_huge_index(self):
998
        src = torch.empty(15000000, 45, device="cuda", dtype=torch.long).random_(
999
            0, 2**22
1000
        )
1001
        idx = torch.randperm(src.shape[0], device="cuda")
1002
        res = src[idx]
1003
        res_cpu = src.cpu()[idx.cpu()]
1004
        self.assertEqual(res.cpu(), res_cpu)
1005

1006
    def test_randint_randomness_for_large_range(self) -> None:
1007
        # For large ranges, randint generation is slightly different. This lead to a subtle bug where some Philox
1008
        # offsets were not calculated correctly, resulting in reused random states.
1009
        # See https://github.com/pytorch/pytorch/issues/125224
1010
        size = 1_000_000
1011
        high = 6_000_000_000  # Keep this above 2**32
1012

1013
        def run(dev: torch.device) -> int:
1014
            # Measure how many unique numbers are generated in 2 consecutive calls to randint. If random states are
1015
            # reused, this will yield fewer unique numbers.
1016
            gen = torch.Generator(device=dev)
1017
            gen.manual_seed(0)
1018
            t1 = torch.randint(
1019
                0, high, [size], device=dev, generator=gen, dtype=torch.int64
1020
            )
1021
            t2 = torch.randint(
1022
                0, high, [size], device=dev, generator=gen, dtype=torch.int64
1023
            )
1024
            return torch.stack([t1, t2]).unique().shape[0]
1025

1026
        # Use CPU as reference. The results should not deviate too much.
1027
        assert abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000
1028

1029
    @parametrize("dtype", [torch.float32, torch.double])
1030
    def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
1031
        # Test if random states do not overlap between consecutive rand/randn calls.
1032
        # See https://github.com/pytorch/pytorch/issues/125224
1033

1034
        def run(func, dev: torch.device, dtype: torch.dtype) -> int:
1035
            # Measure how many unique numbers are generated in 2 consecutive calls. If random states are
1036
            # reused, this will yield fewer unique numbers.
1037
            size = 1000000
1038
            gen = torch.Generator(device=dev)
1039
            gen.manual_seed(0)
1040
            t1 = func((size,), device=dev, generator=gen, dtype=dtype)
1041
            t2 = func((size,), device=dev, generator=gen, dtype=dtype)
1042
            return torch.stack([t1, t2]).unique().shape[0]
1043

1044
        # Use CPU as reference. The results should not deviate too much.
1045
        for func in [torch.rand, torch.randn]:
1046
            deviation = abs(
1047
                run(func, torch.device("cuda"), dtype)
1048
                - run(func, torch.device("cpu"), dtype)
1049
            )
1050
            assert deviation < 50_000, deviation
1051

1052
    def test_min_max_inits(self):
1053
        # Testing if THC_reduceAll received the correct index initialization.
1054
        # This affects the result of THC_reduceAll operations at extreme values
1055
        x = torch.cuda.ByteTensor([0])
1056
        y = torch.cuda.ByteTensor([255])
1057
        expected = torch.cuda.LongTensor([0])[0]
1058

1059
        _, v = x.max(dim=0)
1060
        self.assertEqual(v, expected)
1061

1062
        _, v = y.min(dim=0)
1063
        self.assertEqual(v, expected)
1064

1065
    def test_nvtx(self):
1066
        # Just making sure we can see the symbols
1067
        torch.cuda.nvtx.range_push("foo")
1068
        torch.cuda.nvtx.mark("bar")
1069
        torch.cuda.nvtx.range_pop()
1070
        range_handle = torch.cuda.nvtx.range_start("range_start")
1071
        torch.cuda.nvtx.range_end(range_handle)
1072

1073
    def test_bincount_ext(self):
1074
        # ensure CUDA code coverage
1075
        input_size = (100000,)
1076
        w = torch.randn(input_size, dtype=torch.double, device="cuda")
1077
        w_cpu = w.cpu()
1078
        # test shared memory impl
1079
        t = torch.randint(50, input_size, dtype=torch.int8, device="cuda")
1080
        self.assertEqual(t.cpu().bincount(), t.bincount())
1081
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
1082
        # test global memory impl
1083
        #   see `CUDAHistogramMemoryType` in SummaryOps.cu
1084
        #   50000 * sizeof(int64_t) == 390 KiB, which should exceed smem of any known GPU
1085
        t = torch.randint(50000, input_size, dtype=torch.int64, device="cuda")
1086
        self.assertEqual(t.cpu().bincount(), t.bincount())
1087
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
1088

1089
        t = torch.zeros([10], dtype=torch.int32, device="cuda")
1090
        # 35488 * 65536 as int32 would cause overflow to negative value
1091
        # giving negative bin offset
1092
        t[0] = 35488
1093
        counted = t.bincount(minlength=65536)
1094
        self.assertEqual(torch.sum(counted), 10)
1095

1096
    def test_tiny_half_norm_(self):
1097
        a = torch.arange(25).cuda().float()
1098
        a /= 100000000
1099
        b = a.half()
1100
        self.assertGreater(b.norm().item(), 0)
1101

1102
    def test_norm_type_conversion(self):
1103
        a = torch.ones(65536).cuda().half()
1104
        self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)
1105

1106
    def test_cuda_memory_leak_detection_propagates_errors(self):
1107
        with self.assertRaisesRegex(
1108
            RuntimeError, r"The size of tensor a \(3\) must match"
1109
        ):
1110
            with self.assertLeaksNoCudaTensors():
1111
                x = torch.randn(3, 1, device="cuda")
1112
                y = torch.randn(2, 1, device="cuda")
1113
                z = x + y
1114

1115
    @unittest.skipIf(not TEST_MEDIUM_TENSOR, "not enough memory")
1116
    @serialTest()
1117
    def test_cuda_kernel_loop_overflow(self):
1118
        # Issue #24309: In extreme cases, the loop variable could overflow and continue
1119
        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
1120
        x = torch.randn(1, 1, 1, 2**30 + 1, dtype=torch.float16, device="cuda")
1121
        expected = x[0, 0, 0, 2**30]
1122
        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1123
        torch.cuda.synchronize()
1124
        self.assertEqual(y[0, 0, 0, 2**30], expected)
1125

1126
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
1127
    @gcIfJetson
1128
    @serialTest()
1129
    def test_cuda_kernel_loop_overflow_large(self):
1130
        # Make sure input.numel() > INT_MAX is handled:
1131
        x = torch.randn(1, 1, 1, 2**31, dtype=torch.float16, device="cuda")
1132
        with self.assertRaisesRegex(RuntimeError, "integer out of range"):
1133
            y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1134

1135
        # Issue #24309: In extreme cases, the loop variable could overflow and continue
1136
        # the kernel loop with a negative index, causing a RuntimeError (invalid write):
1137
        x = torch.randn(1, 1, 1, 2**31 - 1, dtype=torch.float16, device="cuda")
1138
        expected = x[0, 0, 0, 2**31 - 2]
1139
        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
1140
        torch.cuda.synchronize()
1141
        self.assertEqual(y[0, 0, 0, 2**31 - 2], expected)
1142

1143
    # this might create a reference cycle on self...
1144
    def _make_multiply_in_stream(self):
1145
        class MultiplyInStream(torch.autograd.Function):
1146
            @staticmethod
1147
            def forward(ctx, x, val):
1148
                ctx.val = val
1149
                ctx.stream = torch.cuda.current_stream()
1150
                return x * val
1151

1152
            @staticmethod
1153
            def backward(ctx, grad):
1154
                self.assertEqual(torch.cuda.current_stream(), ctx.stream)
1155
                # delays the operation in the background stream
1156
                torch.cuda._sleep(1000 * 5000)
1157
                return grad * ctx.val, None
1158

1159
        return MultiplyInStream
1160

1161
    @skipCUDANonDefaultStreamIf(True)
1162
    def test_streaming_backwards_sync(self):
1163
        default_stream = torch.cuda.current_stream()
1164
        stream = torch.cuda.Stream()
1165

1166
        MultiplyInStream = self._make_multiply_in_stream()
1167

1168
        # Tests using grads outside the backward() stream context
1169
        # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
1170
        x = torch.randn(5, 5, device="cuda", requires_grad=True)
1171
        with torch.cuda.stream(stream):
1172
            stream.wait_stream(default_stream)
1173
            output = MultiplyInStream.apply(x, 2)
1174
            output.sum().backward()
1175
        # sync needed
1176
        default_stream.wait_stream(stream)
1177
        self.assertEqual(x.grad, torch.ones_like(x) * 2)
1178
        self.assertEqual(torch.cuda.current_stream(), default_stream)
1179

1180
        # Tests that using grads in the same stream context as backward()
1181
        # is safe regardless what streams bwd ops ran on
1182
        bwd_ambient_stream = torch.cuda.Stream()
1183
        x = torch.randn(5, 5, device="cuda", requires_grad=True)
1184
        with torch.cuda.stream(stream):
1185
            stream.wait_stream(default_stream)
1186
            output = MultiplyInStream.apply(x, 3)
1187
        with torch.cuda.stream(bwd_ambient_stream):
1188
            bwd_ambient_stream.wait_stream(stream)
1189
            output.sum().backward()
1190
            # x was first used on "stream" so its AccumulateGrad leaf should run on "stream".
1191
            # The end of backward() should have synced "bwd_ambient_stream" with "stream"
1192
            # so it should be safe to use x.grad here without any syncs.
1193
            self.assertEqual(x.grad, torch.ones_like(x) * 3)
1194
            self.assertEqual(torch.cuda.current_stream(), bwd_ambient_stream)
1195

1196
    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
1197
    @skipIfRocm(msg="flakey on ROCm https://github.com/pytorch/pytorch/issues/53190")
1198
    def test_streaming_backwards_multiple_streams(self):
1199
        MultiplyInStream = self._make_multiply_in_stream()
1200

1201
        class StreamModel(torch.nn.Module):
1202
            def __init__(self) -> None:
1203
                super().__init__()
1204
                self.event = torch.cuda.Event()
1205
                self.stream0 = torch.cuda.Stream()
1206
                self.stream1 = torch.cuda.Stream()
1207

1208
            def forward(self, x, x_first_use_on_ambient):
1209
                if x_first_use_on_ambient:
1210
                    x0 = x.clone()
1211
                self.stream0.wait_stream(torch.cuda.current_stream())
1212
                self.stream1.wait_stream(torch.cuda.current_stream())
1213
                with torch.cuda.stream(self.stream0):
1214
                    if not x_first_use_on_ambient:
1215
                        x0 = x.clone()
1216
                    y0 = MultiplyInStream.apply(x0, 2)
1217
                    self.event.record(stream=torch.cuda.current_stream())
1218

1219
                with torch.cuda.stream(self.stream1):
1220
                    y1 = MultiplyInStream.apply(x, 3)
1221
                    self.stream1.wait_event(self.event)
1222
                    return y0 + y1
1223

1224
        stream = torch.cuda.Stream()
1225

1226
        for x_first_use_on_ambient in (True, False):
1227
            # the out_of_place=False, iters=1 case stresses if proper syncs are inserted
1228
            # when grads are initially None and stolen by backward ops.
1229
            for out_of_place, iters in ((True, 1), (False, 1), (False, 5)):
1230
                with torch.cuda.stream(stream):
1231
                    x = torch.randn(5, 5, device="cuda", requires_grad=True)
1232
                    model = StreamModel().cuda()
1233
                    x.register_hook(
1234
                        lambda grad: self.assertEqual(
1235
                            torch.cuda.current_stream(),
1236
                            stream if x_first_use_on_ambient else model.stream0,
1237
                        )
1238
                    )
1239
                    for p in model.parameters():
1240
                        self.assertTrue(p.grad is None)
1241
                    for i in range(iters):
1242
                        loss = model(x, x_first_use_on_ambient).sum()
1243
                        if out_of_place:
1244
                            x_grad = torch.autograd.grad((loss,), (x,))[0]
1245
                        else:
1246
                            loss.backward()
1247
                # See "Stream semantics of backward passes" on https://pytorch.org/docs/stable/notes/cuda.html
1248
                torch.cuda.current_stream().wait_stream(stream)
1249

1250
                if out_of_place:
1251
                    self.assertEqual(x_grad, torch.ones_like(x) * 5 * iters)
1252
                else:
1253
                    self.assertEqual(x.grad, torch.ones_like(x) * 5 * iters)
1254

1255
    def test_streaming_backwards_sync_graph_root(self):
1256
        # This function tests if bwd ops running on a side stream properly sync with the GraphRoot.
1257
        # The potential bug it targets is a race condition. The test uses multiple trials and
1258
        # torch.cuda._sleep such that if the race condition exists, the test will almost certainly fail,
1259
        # but there's a chance it may spuriously pass. Passing does not guarantee the backend is bug-free,
1260
        # but failure does guarantee there is a bug.
1261
        fwd_bwd_op_stream = torch.cuda.Stream()
1262
        bwd_ambient_stream = torch.cuda.Stream()
1263
        # We need these streams to be different otherwise the test is meaningless.
1264
        self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream)
1265

1266
        size = int(1e3)
1267

1268
        a = torch.full((size,), 2.0, device="cuda", requires_grad=True)
1269
        b = torch.full((size,), 3.0, device="cuda", requires_grad=True)
1270

1271
        # I don't think we need any manual record_streams below.
1272
        # a and b remain in scope for the entire test.
1273
        # c and grad remain in scope for each iteration, and there's a full sync between iterations.
1274
        for trial in range(5):
1275
            torch.cuda.synchronize()
1276
            a.grad = b.grad = None
1277
            with torch.cuda.stream(fwd_bwd_op_stream):
1278
                c = a * b
1279

1280
            with torch.cuda.stream(bwd_ambient_stream):
1281
                torch.cuda.synchronize()
1282
                # Long-running dummy kernel on bwd_ambient_stream delays filling of grad
1283
                torch.cuda._sleep(int(50 * get_cycles_per_ms()))
1284
                # Fills grad on bwd_ambient_stream
1285
                grad = torch.full((size,), float(trial + 1), device="cuda")
1286

1287
                # Bwd ops still run on fwd_bwd_ops_stream, so the following will likely fail if
1288
                # bwd ops don't sync with bwd_ambient_stream before consuming grad.
1289
                torch.autograd.backward(tensors=c, grad_tensors=grad)
1290

1291
                # See https://github.com/pytorch/pytorch/issues/47028
1292
                # assertEquals below run on bwd_ambient_stream, so this test may also fail
1293
                # if backward() fails to sync with bwd_ambient_stream at the end.
1294
                # Synchronizing here works around the issue until a proper fix can be made.
1295
                torch.cuda.synchronize()
1296
                with torch.no_grad():
1297
                    self.assertEqual(a.grad, grad * b)
1298
                    self.assertEqual(b.grad, grad * a)
1299

1300
    def test_streaming_backwards_callback(self):
1301
        # Tests if autograd callbacks sync properly with respect to leaf streams and
1302
        # the user-facing stream surrounding backward(). If it fails, first suspect is
1303
        # sync logic where  "final_callbacks_" are called in torch/csrc/autograd/engine.cpp
1304
        MultiplyInStream = self._make_multiply_in_stream()
1305

1306
        size = int(1e3)
1307
        a = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
1308
        b = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
1309

1310
        s0 = torch.cuda.Stream()
1311
        s1 = torch.cuda.Stream()
1312
        s2 = torch.cuda.Stream()
1313

1314
        stash = []
1315

1316
        # sets up a nontrivial structure of leaf streams
1317
        s0.wait_stream(torch.cuda.current_stream())
1318
        with torch.cuda.stream(s0):
1319
            c = MultiplyInStream.apply(a, 2)
1320

1321
        s1.wait_stream(torch.cuda.current_stream())
1322
        with torch.cuda.stream(s1):
1323
            d = MultiplyInStream.apply(b, 3)
1324
            s1.wait_stream(s0)
1325
            e = c * d
1326

1327
            def clone_leaf_grads():
1328
                stash.append(a.grad.clone())
1329
                stash.append(b.grad.clone())
1330

1331
            # Use a hook on e to install the callback
1332
            e.register_hook(
1333
                lambda grad: torch.autograd.Variable._execution_engine.queue_callback(
1334
                    clone_leaf_grads
1335
                )
1336
            )
1337

1338
        s2.wait_stream(s1)
1339
        with torch.cuda.stream(s2):
1340
            e.sum().backward()
1341
            # The autograd engine should sync s2 with all leaf streams then run the callback clone_leaf_grads on s2.
1342
            # If those things happened properly, checking the values of the cloned grads on s2 should be safe:
1343
            self.assertEqual(stash[0], torch.full_like(a, 6))
1344
            self.assertEqual(stash[1], torch.full_like(a, 6))
1345

1346
    @unittest.skipIf(
1347
        TEST_WITH_ROCM,
1348
        "In ROCm, kernel asserts are disabled due to performance overhead",
1349
    )
1350
    def test_fixed_cuda_assert_async(self):
1351
        with self.assertRaisesRegex(
1352
            RuntimeError, "Boolean value of Tensor with no values is ambiguous"
1353
        ):
1354
            torch._assert_async(torch.tensor([], device="cuda"))
1355
        with self.assertRaisesRegex(
1356
            RuntimeError,
1357
            "Boolean value of Tensor with more than one value is ambiguous",
1358
        ):
1359
            torch._assert_async(torch.tensor([0, 0], device="cuda"))
1360

1361
        torch._assert_async(torch.tensor(1, device="cuda"))
1362
        torch._assert_async(torch.tensor(0.1, device="cuda"))
1363
        torch._assert_async(torch.tensor(-0.1, device="cuda"))
1364
        torch._assert_async(torch.tensor(True, device="cuda"))
1365
        torch._assert_async(torch.tensor(0 + 0.1j, device="cuda"))
1366

1367
        fail_stmts = [
1368
            "torch._assert_async(torch.tensor(0, device='cuda'))",
1369
            "torch._assert_async(torch.tensor(0.0, device='cuda'))",
1370
            "torch._assert_async(torch.tensor(False, device='cuda'))",
1371
            "torch._assert_async(torch.tensor(0 + 0j, device='cuda'))",
1372
        ]
1373

1374
        import subprocess
1375

1376
        for stmt in fail_stmts:
1377
            with self.subTest(stmt=stmt):
1378
                r = subprocess.call(
1379
                    [
1380
                        sys.executable,
1381
                        "-c",
1382
                        f"""\
1383
import torch
1384

1385
{stmt}
1386
torch.cuda.synchronize()
1387
""",
1388
                    ]
1389
                )
1390
                self.assertTrue(r != 0)
1391

1392
    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "FAIL")
    def test_cublas_multiple_threads_same_device(self):
        """Concurrent matmuls from several threads on one device must not race.

        Each thread repeatedly replaces ``results[t]`` with
        ``results[t] @ weight`` and divides it by ``size``, on its own stream.
        With all-ones operands every step maps an all-ones matrix back to all
        ones. The final sum is therefore exactly ``size * size`` unless a gemm
        and its ``div_`` ended up racing.
        """
        # Note, these parameters should be very carefully tuned.
        # Too small a number makes it hard for the race condition
        # to happen, while too large a number sometimes causes a hang.
        size = 1024
        num_threads = 2
        trials = 3
        test_iters = 100

        weight = torch.ones((size, size), device="cuda")
        results = {}
        barrier = threading.Barrier(num_threads)

        def _worker(t):
            my_stream = torch.cuda.Stream()
            # Hard sync so we don't need to worry about creating and using tensors
            # across streams or the fact that default streams are thread-local.
            # Those issues are not the target of this test.
            torch.cuda.synchronize()
            # Line up threads to increase likelihood of race conditions.
            barrier.wait()
            with torch.cuda.stream(my_stream):
                for _ in range(test_iters):
                    # If all threads are sharing the same cublas handle,
                    # the following sequence may occur:
                    # thread 0 calls cublasSetStream()
                    # thread 1 calls cublasSetStream()
                    # thread 0 launches its raw gemm, which it thinks is in
                    #          its own stream, but is actually in thread 1's stream.
                    # thread 0 enqueues its div_, which IS in its own stream,
                    #          but actually now races with its gemm.
                    results[t] = torch.mm(results[t], weight)
                    results[t].div_(float(size))
            torch.cuda.synchronize()

        for _ in range(trials):
            for t in range(num_threads):
                results[t] = torch.ones((size, size), device="cuda")

            threads = [
                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            for t in range(num_threads):
                self.assertEqual(results[t].sum().item(), size * size)
1443

1444
    # Test is flaky on Windows (https://github.com/pytorch/pytorch/issues/57401)
    @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows (see issue 57401)")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    @skipIfRocm
    def test_cudnn_multiple_threads_same_device(self):
        """Concurrent cuDNN convolutions from several threads must not race.

        Each thread repeatedly convolves its all-ones input with a 2x2
        all-ones kernel (padding 0) and divides by 4, on its own stream. Every
        step shrinks both spatial dims by one and keeps all values equal to 1,
        so the final sum is ``(input_size - test_iters) ** 2`` unless a
        convolution and its ``div_`` raced.
        """
        # This function is intended to test the lazy creation and reuse of per-thread
        # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp.
        # Failure here likely indicates something wrong with that logic.
        weight = torch.ones((1, 1, 2, 2), device="cuda")

        results = {}

        # Spatial size of each thread's square input; shared by the input
        # construction and the expected sum so the two cannot drift apart.
        input_size = 2048
        num_threads = 2
        trials = 3
        test_iters = 1000
        barrier = threading.Barrier(num_threads)

        with torch.backends.cudnn.flags(enabled=True):

            def _worker(t):
                my_stream = torch.cuda.Stream()
                # Hard sync so we don't need to worry about creating and using tensors
                # across streams or the fact that default streams are thread-local.
                # Those issues are not the target of this test.
                torch.cuda.synchronize()
                # Line up threads to increase likelihood of race conditions.
                barrier.wait()
                with torch.cuda.stream(my_stream):
                    for _ in range(test_iters):
                        # If all threads are sharing the same cudnn handle,
                        # the following sequence may occur:
                        # thread 0 calls setCuDNNStreamToCurrent()
                        # thread 1 calls setCuDNNStreamToCurrent()
                        # thread 0 launches its raw convolution, which it thinks is in
                        #          its own stream, but is actually in thread 1's stream.
                        # thread 0 enqueues its div_, which IS in its own stream,
                        #          but now races with its convolution.
                        results[t] = torch.nn.functional.conv2d(
                            results[t], weight, padding=0
                        )
                        results[t].div_(4.0)
                torch.cuda.synchronize()

            # Each conv step trims one row and one column.
            expected_sum = (input_size - test_iters) * (input_size - test_iters)

            for _ in range(trials):
                for t in range(num_threads):
                    results[t] = torch.ones(
                        (1, 1, input_size, input_size), device="cuda"
                    )

                threads = [
                    threading.Thread(target=_worker, args=(t,))
                    for t in range(num_threads)
                ]

                for thread in threads:
                    thread.start()
                for thread in threads:
                    thread.join()

                for t in range(num_threads):
                    self.assertEqual(results[t].sum().item(), expected_sum)
1506

1507
    def test_cusparse_multiple_threads_same_device(self):
        """Concurrent sparse-dense matmuls from several threads must not race.

        Each thread repeatedly computes ``weight.mm(results[t])`` with an
        all-ones sparse COO ``weight`` and divides by ``size``, on its own
        stream. Every step maps an all-ones dense matrix back to all ones, so
        the final sum is exactly ``size * size`` unless a matmul and its
        ``div_`` raced.
        """
        size = 1024
        num_threads = 2
        trials = 3
        test_iters = 500

        def ones_sparse(size):
            # Sparse COO tensor with every (i, j) entry of a size x size
            # matrix explicitly stored as 1.
            a = torch.arange(size, device="cuda")
            indices = torch.cartesian_prod(a, a).t()
            values = torch.ones(size * size, device="cuda")
            return torch.sparse_coo_tensor(indices, values)

        weight = ones_sparse(size)
        results = {}
        barrier = threading.Barrier(num_threads)

        def _worker(t):
            my_stream = torch.cuda.Stream()
            # Hard sync so we don't need to worry about creating and using tensors
            # across streams or the fact that default streams are thread-local.
            # Those issues are not the target of this test.
            torch.cuda.synchronize()
            # Line up threads to increase likelihood of race conditions.
            barrier.wait()
            with torch.cuda.stream(my_stream):
                for _ in range(test_iters):
                    # If all threads are sharing the same cuSPARSE handle,
                    # the following sequence may occur:
                    # thread 0 calls cusparseSetStream()
                    # thread 1 calls cusparseSetStream()
                    # thread 0 launches its sparse matmul, which it thinks is in
                    #          its own stream, but is actually in thread 1's stream.
                    # thread 0 enqueues its div_, which IS in its own stream,
                    #          but actually now races with its matmul.
                    results[t] = weight.mm(results[t])
                    results[t].div_(float(size))
            torch.cuda.synchronize()

        for _ in range(trials):
            for t in range(num_threads):
                results[t] = torch.ones((size, size), device="cuda")

            threads = [
                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            for t in range(num_threads):
                self.assertEqual(results[t].sum().item(), size * size)
1560

1561
    def _run_autocast_outofplace(
        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
    ):
        """Run ``op`` under CUDA autocast and check its dtype and numerics.

        The op is called as ``module.<op>`` (when ``module`` has it) and as
        ``Tensor.<op>`` (when ``torch.Tensor`` has it).

        - Every tensor output must have dtype ``out_type``.
        - When both variants exist, their outputs must be identical.
        - The result must bitwise-match a control run with autocast disabled,
          in which floating-point tensor args are cast to ``run_as_type`` by
          hand.

        Args:
            op: Name of the op to look up.
            args: Positional args; ``args[0]`` is the receiver for the
                ``Tensor`` method variant.
            run_as_type: dtype autocast is expected to run the op in.
                bfloat16 autocast is used when this is ``torch.bfloat16``,
                float16 autocast otherwise.
            out_type: Expected output dtype; defaults to ``run_as_type``.
            module: Namespace providing the free-function variant, or ``None``
                to test only the ``Tensor`` method.
            add_kwargs: Keyword args forwarded (uncast) to every call.
        """

        def cast(val, to_type):
            # Recursively cast floating-point tensors inside (nested) sequences.
            # Strings are iterable but must pass through untouched; rebuilding
            # them via ``str(generator)`` would yield a generator repr.
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, (str, bytes)):
                return val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}
        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled())
        with torch.autocast("cuda", dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled())

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )

            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    # str(dtype) already carries the "torch." prefix.
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for Tensor.{op} produced {output_method.dtype}, should produce {out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, (str, bytes)):
                    return first == second
                elif isinstance(first, collections.abc.Iterable):
                    # zip() alone would silently ignore a length mismatch.
                    first_items, second_items = list(first), list(second)
                    return len(first_items) == len(second_items) and all(
                        compare(f, s) for f, s in zip(first_items, second_items)
                    )
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertIs(type(output), type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.autocast("cuda", enabled=False):
                self.assertFalse(torch.is_autocast_enabled())

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertIs(type(output_to_compare), type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_enabled())
        self.assertFalse(torch.is_autocast_enabled())
1643

1644
    def args_maybe_kwargs(self, op_with_args):
        """Split an autocast-list entry into ``(op, args, kwargs)``.

        Entries are either ``(op, args)`` or ``(op, args, kwargs)``. A
        two-item entry gets a fresh empty dict for its kwargs.
        """
        kwargs = {} if len(op_with_args) == 2 else op_with_args[2]
        return op_with_args[0], op_with_args[1], kwargs
1649

1650
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_fp16(self):
        # Entries are (op, args) or (op, args, skip); a truthy third item
        # (noted in the original as TEST_WITH_ROCM) skips that op.
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for entry in self.autocast_lists.torch_fp16:
                name, inputs = entry[0], entry[1]
                skip = entry[2] if len(entry) == 3 else False
                if skip:
                    continue
                self._run_autocast_outofplace(name, inputs, torch.float16)
1660

1661
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_bf16(self):
        # Runs the fp16 op list under bfloat16 autocast. cuDNN ops are expected
        # to raise when the TORCH_CUDNN_V8_API_DISABLED env var is set to a
        # non-zero integer or the device capability is below (8, 0). Otherwise,
        # ops must succeed if the device supports bf16 and raise if it does not.
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for entry in self.autocast_lists.torch_fp16:
                name, inputs = entry[0], entry[1]
                skip = entry[2] if len(entry) == 3 else False
                # Parenthesized for readability; `and` binds tighter than `or`,
                # so this matches the short-circuit order of the unparenthesized form.
                expect_unsupported = "cudnn" in name and (
                    (
                        "TORCH_CUDNN_V8_API_DISABLED" in os.environ
                        and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
                    )
                    or torch.cuda.get_device_capability() < (8, 0)
                )
                if skip:
                    continue
                if expect_unsupported:
                    with self.assertRaises(
                        RuntimeError,
                        msg=str(name) + " should not be supported for bfloat16!",
                    ):
                        self._run_autocast_outofplace(name, inputs, torch.bfloat16)
                elif torch.cuda.is_bf16_supported():
                    self._run_autocast_outofplace(name, inputs, torch.bfloat16)
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "Device does not support bfloat16"
                    ):
                        self._run_autocast_outofplace(name, inputs, torch.bfloat16)
1690

1691
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_fp32(self):
        # torch.* ops expected to run in float32; entries may carry a kwargs dict.
        for entry in self.autocast_lists.torch_fp32:
            name, inputs, kwargs = self.args_maybe_kwargs(entry)
            self._run_autocast_outofplace(
                name, inputs, torch.float32, add_kwargs=kwargs
            )
1698

1699
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_need_autocast_promote(self):
        # Ops from the need_autocast_promote list are expected to run in float32.
        promote_cases = self.autocast_lists.torch_need_autocast_promote
        for name, inputs in promote_cases:
            self._run_autocast_outofplace(name, inputs, torch.float32)
1703

1704
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_expect_builtin_promote(self):
        # Each entry supplies its own expected output dtype.
        promote_cases = self.autocast_lists.torch_expect_builtin_promote
        for name, inputs, expected_dtype in promote_cases:
            self._run_autocast_outofplace(
                name, inputs, torch.float32, out_type=expected_dtype
            )
1708

1709
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_fp16(self):
        # nn ops (looked up on torch._C._nn) expected to run in float16.
        nn_namespace = torch._C._nn
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for name, inputs in self.autocast_lists.nn_fp16:
                self._run_autocast_outofplace(
                    name, inputs, torch.float16, module=nn_namespace
                )
1716

1717
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_bf16(self):
        # Reuses the fp16 nn op list under bfloat16 autocast. Devices without
        # bf16 support must raise instead of running.
        nn_namespace = torch._C._nn
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for name, inputs in self.autocast_lists.nn_fp16:
                if not torch.cuda.is_bf16_supported():
                    with self.assertRaisesRegex(
                        RuntimeError, "Device does not support bfloat16"
                    ):
                        self._run_autocast_outofplace(
                            name, inputs, torch.bfloat16, module=nn_namespace
                        )
                    continue
                self._run_autocast_outofplace(
                    name, inputs, torch.bfloat16, module=nn_namespace
                )
1732

1733
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_fp32(self):
        # nn ops (looked up on torch._C._nn) expected to run in float32.
        nn_namespace = torch._C._nn
        for name, inputs in self.autocast_lists.nn_fp32:
            self._run_autocast_outofplace(
                name, inputs, torch.float32, module=nn_namespace
            )
1737

1738
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_linalg_fp16(self):
        # linalg ops (looked up on torch._C._linalg) expected to run in float16.
        linalg_namespace = torch._C._linalg
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for name, inputs in self.autocast_lists.linalg_fp16:
                self._run_autocast_outofplace(
                    name, inputs, torch.float16, module=linalg_namespace
                )
1745

1746
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_fp16(self):
        # Tensor-method-only ops (module=None) expected to run in float16.
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for name, inputs in self.autocast_lists.methods_fp16:
                self._run_autocast_outofplace(
                    name, inputs, torch.float16, module=None
                )
1751

1752
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_fp32(self):
        # Tensor-method-only ops (module=None) expected to run in float32.
        method_cases = self.autocast_lists.methods_fp32
        for name, inputs in method_cases:
            self._run_autocast_outofplace(name, inputs, torch.float32, module=None)
1756

1757
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_expect_builtin_promote(self):
        # Tensor-method-only ops; each entry supplies its expected output dtype.
        method_cases = self.autocast_lists.methods_expect_builtin_promote
        for name, inputs, expected_dtype in method_cases:
            self._run_autocast_outofplace(
                name, inputs, torch.float32, out_type=expected_dtype, module=None
            )
1763

1764
    def test_autocast_banned(self):
        # Every (op, args, module) entry in the banned list must raise
        # RuntimeError when called inside a CUDA autocast region.
        banned_cases = self.autocast_lists.banned
        with torch.autocast("cuda"):
            for name, inputs, namespace in banned_cases:
                with self.assertRaises(RuntimeError):
                    getattr(namespace, name)(*inputs)
1769

1770
    def test_autocast_ignored_types(self):
        """Autocast must not change the result dtype for float64/int32 inputs.

        For ops representing several cast policies, the dtype produced inside
        the autocast region is compared with the dtype produced by the same
        call with autocast disabled.
        """
        with torch.autocast("cuda"):
            for ignore_type in (torch.double, torch.int32):
                a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
                b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
                c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")

                # Tests if CastPolicy::fp16 ops ignore double and int
                # Currently, no ops belonging to this policy support integer inputs.
                if ignore_type is torch.double:
                    # The float64 operand is left alone, so mixing it with a
                    # float16 operand is expected to raise.
                    with self.assertRaises(RuntimeError):
                        torch.mm(a_ignore, c_16)
                    with torch.autocast("cuda", enabled=False):
                        type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
                    self.assertIs(torch.mm(a_ignore, b_ignore).dtype, type_no_autocast)

                # Tests if CastPolicy::fp32 ops ignore double and int
                with torch.autocast("cuda", enabled=False):
                    type_no_autocast = torch.pow(a_ignore, 2.0).dtype
                self.assertIs(torch.pow(a_ignore, 2.0).dtype, type_no_autocast)

                # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
                with torch.autocast("cuda", enabled=False):
                    type_no_autocast = torch.sum(a_ignore).dtype
                self.assertIs(torch.sum(a_ignore).dtype, type_no_autocast)

                # Tests if CastPolicy::fp32_append_dtype ops ignore double and int
                # Currently, no ops belonging to this policy support integer inputs.
                if ignore_type is torch.double:
                    with torch.autocast("cuda", enabled=False):
                        type_no_autocast = torch.norm(a_ignore).dtype
                    self.assertIs(torch.norm(a_ignore).dtype, type_no_autocast)
1804

1805
    def test_autocast_custom_enabled(self):
        """custom_fwd/custom_bwd without cast_inputs leave autocast enabled.

        forward must receive the original float32 inputs with autocast on and
        produce the autocast dtype. backward must also see autocast on and
        produce gradients in that dtype. Repeated for float16 and, when
        TEST_BF16 is set, bfloat16.
        """

        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda")
            def forward(ctx, a, b):
                self.assertIs(a.dtype, torch.float32)
                self.assertIs(b.dtype, torch.float32)
                self.assertTrue(torch.is_autocast_enabled())
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                self.assertTrue(torch.is_autocast_enabled())
                a, b = ctx.saved_tensors
                a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
                # `dtype` is the loop variable below, looked up when backward
                # runs; backward is called within the same loop iteration.
                self.assertIs(a_grad.dtype, dtype)
                self.assertIs(b_grad.dtype, dtype)
                return a_grad, b_grad

        mymm = MyMM.apply

        x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
        y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)

        dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
        for dtype in dtypes:
            # torch.cuda.amp.autocast(...) is deprecated in favor of torch.autocast("cuda", ...).
            with torch.autocast("cuda", dtype=dtype):
                output = mymm(x, y)
                self.assertIs(output.dtype, dtype)
                loss = output.sum()
            loss.backward()
1837

1838
    def test_autocast_custom_cast_inputs(self):
        """custom_fwd(cast_inputs=float32) casts nested inputs and disables autocast.

        Inside an autocast region, the float16 inputs (one nested in a
        container) must reach forward as float32, and both forward and
        backward must run with autocast off. Outside autocast, custom_fwd
        must be a no-op, so the inputs stay float16.
        """
        test_case = self

        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
            def forward(ctx, a, container, expect_type):
                b = container[1][0]
                for operand in (a, b):
                    test_case.assertTrue(operand.dtype is expect_type)
                test_case.assertFalse(torch.is_autocast_enabled())
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                test_case.assertFalse(torch.is_autocast_enabled())
                _, b = ctx.saved_tensors
                return grad.mm(b.t()), None, None

        mymm = MyMM.apply

        x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
        # Puts one input tensor in a nested container.  y's contained Tensor won't receive a gradient,
        # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments.
        # Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
        nested = torch.randn(
            (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
        )
        y = (0, {0: nested})

        # First pass: inside autocast, inputs arrive cast to float32.
        # Second pass: outside autocast, custom_fwd is a no-op and they stay float16.
        scenarios = (
            (torch.autocast("cuda"), torch.float32),
            (contextlib.nullcontext(), torch.float16),
        )
        for region, expected_dtype in scenarios:
            with region:
                output = mymm(x, y, expected_dtype)
                self.assertTrue(output.dtype is expected_dtype)
                loss = output.sum()
            loss.backward()
1883

1884
    def test_autocast_custom_deprecated_warning(self):
        """The torch.cuda.amp.custom_fwd/custom_bwd aliases must warn as deprecated.

        The Function decorated with them must still run under autocast.
        """
        with warnings.catch_warnings(record=True) as w:
            # Record every occurrence. Under the default filter a warning from
            # a given location is shown only once per process, so an earlier
            # trigger (e.g. a rerun) would otherwise leave `w` empty.
            warnings.simplefilter("always")

            class MyMM(torch.autograd.Function):
                @staticmethod
                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
                def forward(ctx, x, y):
                    ctx.save_for_backward(x, y)
                    self.assertFalse(torch.is_autocast_enabled())
                    return x + y

                @staticmethod
                @torch.cuda.amp.custom_bwd
                def backward(ctx, grad):
                    _, _ = ctx.saved_tensors
                    self.assertFalse(torch.is_autocast_enabled())
                    return grad, grad

        # Search all recorded messages instead of indexing w[0]/w[1], so an
        # unrelated warning recorded first cannot shift positions or cause an
        # IndexError.
        recorded = "\n".join(str(entry.message) for entry in w)
        self.assertRegex(
            recorded, r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
        )
        self.assertRegex(
            recorded, r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
        )

        mymm = MyMM.apply
        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)
        with torch.amp.autocast("cuda"):
            output = mymm(x, y)
            loss = output.sum()
        loss.backward()
1916

1917
    def test_autocast_cat_jit(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/38958.
        # TorchScript itself is incidental: scripting is just a way to call
        # torch.cat via the boxed API while autocast is enabled.

        class CatStack(torch.nn.Module):
            def forward(self):
                first = torch.randn(1)
                second = torch.randn(1)
                joined = torch.cat((first, second), 0)
                return torch.stack([joined, joined], 0)

        eager = CatStack()
        scripted = torch.jit.script(eager)

        with torch.autocast("cuda", enabled=True):
            for runner in (eager, scripted):
                runner()
1936

1937
    # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened)
    # so they get a dedicated test.
    # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100).
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_rnn(self):
        """Compare autocast RNN/GRU/LSTM runs with an explicit float16 control.

        Sweeps the RNN class, layer count, bias, input layout, direction,
        whether parameters are re-cloned (non-preflattened), and the
        input/hidden/weight dtypes. For every case, the autocast outputs,
        hidden states and parameter gradients must match running the same
        module after converting it and its inputs to float16 without autocast.
        """
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            clses = ("RNN", "GRU", "LSTM")
            # seq, batch, features, hidden size
            T, B, F, H = 3, 4, 5, 6
            dtypes = (torch.float16, torch.float32)
            input_layouts = ("seq_first", "batch_first", "packed")

            for (
                cls,
                num_layers,
                bias,
                input_layout,
                bidirectional,
                try_nonpreflattened_weights,
                input_dtype,
                hidden_dtype,
                weight_dtype,
            ) in product(
                clses,
                (1, 2),
                (True, False),
                input_layouts,
                (True, False),
                (True, False),
                dtypes,
                dtypes,
                dtypes,
            ):
                if input_layout == "seq_first":
                    batch_first = False
                    x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
                elif input_layout == "batch_first":
                    batch_first = True
                    x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
                elif input_layout == "packed":
                    batch_first = False
                    # One length per batch element (B == 4).
                    x = torch.nn.utils.rnn.pack_padded_sequence(
                        torch.randn((T, B, F), device="cuda", dtype=input_dtype),
                        lengths=(3, 2, 1, 3),
                        enforce_sorted=False,
                    )

                rnn = (
                    getattr(torch.nn, cls)(
                        F,
                        H,
                        num_layers=num_layers,
                        bidirectional=bidirectional,
                        bias=bias,
                        batch_first=batch_first,
                    )
                    .cuda()
                    .to(dtype=weight_dtype)
                )

                if try_nonpreflattened_weights:
                    # NOTE(review): re-pointing each parameter at a fresh clone
                    # presumably detaches it from the module's flattened weight
                    # buffer, exercising the non-preflattened path; confirm.
                    for p in rnn.parameters():
                        with torch.no_grad():
                            p.set_(p.clone())

                h = torch.randn(
                    (num_layers * (2 if bidirectional else 1), B, H),
                    device="cuda",
                    dtype=hidden_dtype,
                )
                if cls == "LSTM":
                    c = torch.randn(
                        (num_layers * (2 if bidirectional else 1), B, H),
                        device="cuda",
                        dtype=hidden_dtype,
                    )
                    h = (h, c)

                with torch.autocast("cuda"):
                    out, h_out = rnn(x, h)
                out = out.data if input_layout == "packed" else out
                self.assertEqual(out.dtype, torch.float16)
                # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed.  This check can't guarantee
                # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has
                # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed.
                self.assertEqual(
                    out.grad_fn.name(),
                    "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
                )
                out.sum().backward()
                grads = [p.grad.clone() for p in rnn.parameters()]

                rnn.zero_grad()

                # Control run: convert the module (in place) and the inputs to
                # float16 explicitly, with autocast off.
                if cls == "LSTM":
                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
                        x.half(), (h[0].half(), h[1].half())
                    )
                else:
                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
                        x.half(), h.half()
                    )
                out_control = (
                    out_control.data if input_layout == "packed" else out_control
                )
                out_control.sum().backward()
                grads_control = [p.grad.clone() for p in rnn.parameters()]

                # Compares with default tolerances, even for FP16 execution.  Barring nondeterminism,
                # autocast and control results should be bitwise identical.
                self.assertEqual(out, out_control)

                if cls == "LSTM":
                    # Check each state separately so a failure names the culprit.
                    self.assertEqual(h_out[0].dtype, torch.float16)
                    self.assertEqual(h_out[1].dtype, torch.float16)
                    self.assertEqual(h_out[0], h_out_control[0])
                    self.assertEqual(h_out[1], h_out_control[1])
                else:
                    self.assertEqual(h_out.dtype, torch.float16)
                    self.assertEqual(h_out, h_out_control)
                for grad, grad_control in zip(grads, grads_control):
                    self.assertEqual(grad.half(), grad_control)
2061

2062
    def test_autocast_cache_leak(self):
        """Autocast must not re-cache the same parameters under ``torch.no_grad()``.

        Regression test for https://github.com/pytorch/pytorch/issues/48049:
        after the first forward pass inside ``autocast`` + ``no_grad``, further
        identical passes must leave ``torch.cuda.memory_allocated()`` unchanged.
        """
        linear = torch.nn.Linear(10, 10).to("cuda")
        data = torch.randn(1, 10, device="cuda")

        with torch.autocast("cuda"):
            with torch.no_grad():
                out = linear(data)
                first_iter_mem = torch.cuda.memory_allocated()
                for _ in range(3):
                    # Rebinding `out` keeps exactly one output alive, so any
                    # growth in allocated memory comes from re-cached casts.
                    out = linear(data)
                # assertEqual reports both values on failure.
                self.assertEqual(first_iter_mem, torch.cuda.memory_allocated())
2077

2078
    def test_autocast_checkpointing(self):
        """Checkpointed regions under autocast yield fp16 outputs that can backprop.

        Runs ``checkpoint_sequential`` inside ``torch.autocast("cuda")`` with both
        reentrant and non-reentrant checkpointing.
        """
        model = torch.nn.Sequential(
            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
        ).cuda()
        # Named `inp` rather than `input` to avoid shadowing the builtin.
        inp = torch.rand(
            (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
        )
        for reentrant in (True, False):
            with torch.autocast("cuda"):
                output = checkpoint_sequential(model, 2, inp, use_reentrant=reentrant)
            self.assertTrue(output.requires_grad)
            self.assertEqual(output.dtype, torch.float16)
            output.sum().backward()
2091

2092
    def test_cuda_autocast_deprecated_warning(self):
        """Entering ``torch.cuda.amp.autocast`` emits a FutureWarning naming its replacement."""
        expected_pattern = r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead."
        with self.assertWarnsRegex(FutureWarning, expected_pattern), torch.cuda.amp.autocast():
            torch.ones(10)
2099

2100
    @slowTest
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @serialTest()
    def test_max_large_axis(self):
        """``max`` over a 2**32-element dim returns the correct value and index.

        The only non-zero element is the last one, so the reported index lies
        beyond the int32 range.
        """
        numel = 2**32
        x = torch.zeros(numel, device="cuda", dtype=torch.int8)
        x[numel - 1] = 1
        max_val, max_idx = x.max(0)
        self.assertEqual(max_val, 1)
        self.assertEqual(max_idx, numel - 1)
2109

2110
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_to_numpy(self):
        """Calling ``.numpy()`` on a CUDA tensor raises ``TypeError``."""
        with self.assertRaises(TypeError):
            torch.empty(1, device="cuda").numpy()
2113

2114
    def test_graph_is_current_stream_capturing(self):
        """``is_current_stream_capturing`` is True only between capture_begin/capture_end."""
        self.assertFalse(torch.cuda.is_current_stream_capturing())

        # The capture part of the check only runs on CUDA (non-ROCm) builds.
        if not TEST_CUDA or TEST_WITH_ROCM:
            return

        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            g = torch.cuda.CUDAGraph()
            self.assertFalse(torch.cuda.is_current_stream_capturing())
            g.capture_begin()
            self.assertTrue(torch.cuda.is_current_stream_capturing())
            g.capture_end()
            # Ending the capture must clear the capturing state again.
            self.assertFalse(torch.cuda.is_current_stream_capturing())
2125

2126
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_simple(self):
        """A captured chain of ten ``+ 1`` ops produces the expected result on replay."""
        s = torch.cuda.Stream()

        with torch.cuda.stream(s):
            a = torch.full((1000,), 1, device="cuda")
            g = torch.cuda.CUDAGraph()
            torch.cuda.empty_cache()
            g.capture_begin()
            b = a
            for _ in range(10):
                b = b + 1
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()

        # Each of the 1000 elements goes 1 -> 11 after ten increments.
        self.assertEqual(b.sum().item(), 11000.0)
2146

2147
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graphsafe_set_get_rng_state(self):
        """Graph-safe RNG state switching inside a CUDA graph matches eager execution.

        The same sequence of ``graphsafe_set_state`` / ``torch.rand`` calls is run
        eagerly on a fresh CUDA generator and captured into a graph on the default
        CUDA generator. After every replay, the generated values and the final
        offsets of both states must agree between the two paths.
        """

        def create_states(generator):
            """Seed ``generator``; return ``(generator, old_state, new_state)``.

            ``old_state`` is the generator's current graph-safe state and
            ``new_state`` is a clone of it. This function does NOT register the
            states with any graph; see ``register_states_to_graph``.
            """
            # Ensure the CUDA generator is initialized.
            torch.rand(1, device="cuda")
            generator.manual_seed(0)

            old_state = generator.graphsafe_get_state()
            new_state = generator.clone_state()
            return generator, old_state, new_state

        def register_states_to_graph(generator_state, graph):
            """Register both states from a ``create_states`` tuple with ``graph``."""
            generator, old_state, new_state = generator_state
            graph.register_generator_state(old_state)
            graph.register_generator_state(new_state)

        def perform_random_generation_steps(generator_state):
            """Draw once from ``new_state``, then twice from ``old_state``.

            Returns the list of three generated tensors.
            """
            generator, old_state, new_state = generator_state
            random_values = []

            generator.graphsafe_set_state(new_state)
            random_values.append(torch.rand(5, device="cuda", generator=generator))

            generator.graphsafe_set_state(old_state)
            random_values.extend(
                [torch.rand(5, device="cuda", generator=generator) for _ in range(2)]
            )

            return random_values

        def get_final_offsets_of_states(generator_state):
            """Return ``(old_state offset, new_state offset)``."""
            generator, old_state, new_state = generator_state
            return old_state.get_offset(), new_state.get_offset()

        # Eager reference path: a fresh, non-default CUDA generator.
        generator = torch.Generator(device="cuda")
        generator_state = create_states(generator)

        # Graphed path: the default CUDA generator, with both states registered.
        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()
        default_generator = torch.cuda.default_generators[0]
        default_generator_state = create_states(default_generator)
        register_states_to_graph(default_generator_state, g)

        # Capture the same RNG sequence into the graph.
        with torch.cuda.stream(s):
            g.capture_begin()
            graphed_random_values = perform_random_generation_steps(
                default_generator_state
            )
            g.capture_end()

        torch.cuda.current_stream().wait_stream(s)
        for _ in range(3):
            random_values = perform_random_generation_steps(generator_state)
            g.replay()
            offset = get_final_offsets_of_states(generator_state)
            graph_offset = get_final_offsets_of_states(default_generator_state)

            # Both paths must have advanced their states identically...
            self.assertEqual(offset, graph_offset)
            # ...and produced the same random values.
            self.assertEqual(random_values, graphed_random_values)
2225

2226
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_memory_stats_of_multiple_generators_and_graphs(self):
        """Memory for registered generator states scales with states, not graphs.

        For several ``(num_graphs, num_generators)`` combinations:
        - every generator state is registered with every graph;
        - each graph captures and replays a tiny RNG op;
        - active blocks/bytes must grow by exactly two 512-byte blocks per state;
        - the stats must return to baseline once the graphs are deleted.
        """

        def clear_cuda_cache():
            """Run the garbage collector and release cached CUDA blocks."""
            gc.collect()
            torch.cuda.empty_cache()

        def simple_graph_task(graph):
            """Capture one ``torch.rand`` into ``graph`` on a side stream, then replay it."""
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
                graph.capture_begin()
                torch.rand(1, device="cuda")
                graph.capture_end()
            torch.cuda.current_stream().wait_stream(s)
            graph.replay()

        def get_memory_stats():
            """Return ``(active block count, active bytes)`` summed over all pools."""
            stats = torch.cuda.memory_stats()
            num_blocks = stats["active.all.current"]
            total_size = stats["active_bytes.all.current"]
            return num_blocks, total_size

        def test(num_graphs, num_generators):
            baseline = get_memory_stats()
            baseline_num_blocks, baseline_total_size = baseline

            graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)]

            # The default generator's current state plus (num_generators - 1) clones.
            default_generator = torch.cuda.default_generators[0]
            generators = [default_generator.graphsafe_get_state()]
            for _ in range(1, num_generators):
                generators.append(default_generator.clone_state())

            for graph in graphs:
                for generator_state in generators:
                    graph.register_generator_state(generator_state)
                simple_graph_task(graph)

            num_blocks, total_size = get_memory_stats()
            # Growth should depend only on the number of generator states, not
            # on how many graphs they are registered with.
            expected_blocks_diff = 2 * num_generators
            expected_size_diff = 2 * 512 * num_generators  # Each block's size is 512
            blocks_diff = num_blocks - baseline_num_blocks
            size_diff = total_size - baseline_total_size

            self.assertEqual(
                blocks_diff,
                expected_blocks_diff,
                f"Unexpected number of active blocks. Expected a change of "
                f"{expected_blocks_diff}, got {blocks_diff}.",
            )
            self.assertEqual(
                size_diff,
                expected_size_diff,
                f"Unexpected total memory size. Expected a change of "
                f"{expected_size_diff} bytes, got {size_diff}.",
            )

            # Drop every reference to the graphs -- including the `graph` loop
            # variable bound above -- so their memory can actually be released.
            while graphs:
                graph = graphs.pop()
                del graph
            clear_cuda_cache()

            after_cleanup = get_memory_stats()
            self.assertEqual(
                after_cleanup,
                baseline,
                f"Memory stats do not match baseline after cleanup. Expected "
                f"{baseline}, got {after_cleanup}.",
            )

        test(1, 1)
        test(3, 2)
        test(10, 20)
2302

2303
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_reset_recapture(self):
        """A graph can be reset and re-captured with different work."""
        s = torch.cuda.Stream()

        with torch.cuda.stream(s):
            a = torch.full((1000,), 1, device="cuda")
            g = torch.cuda.CUDAGraph()
            torch.cuda.empty_cache()
            g.capture_begin()
            b = a
            for _ in range(10):
                b = b + 1
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()

        # 1 + 10 * 1 = 11 per element, 1000 elements.
        self.assertEqual(b.sum().item(), 11000.0)

        g.reset()

        with torch.cuda.stream(s):
            g.capture_begin()
            b.fill_(2.0)
            for _ in range(10):
                b = b + 2
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()
        # 2 + 10 * 2 = 22 per element, 1000 elements.
        self.assertEqual(b.sum().item(), 22000.0)

        g.reset()
        del g
2339

2340
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_debugdump(self):
        """``debug_dump`` can write a DOT file for a graph captured across two streams."""
        torch.cuda.empty_cache()
        x = torch.randn(10240000, device="cuda")
        y = torch.rand_like(x)

        graph = torch.cuda.CUDAGraph()
        graph.enable_debug_mode()

        capture_stream, side_stream = torch.cuda.Stream(), torch.cuda.Stream()
        capture_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(capture_stream):
            graph.capture_begin()
            z = x + y
            # Fork onto a second stream and join back, so the graph spans both.
            with torch.cuda.stream(side_stream):
                side_stream.wait_stream(capture_stream)
                w = z + y
            capture_stream.wait_stream(side_stream)
            graph.capture_end()
        capture_stream.synchronize()
        torch.cuda.synchronize()

        with tempfile.TemporaryDirectory() as tempdir:
            dot_path = os.path.join(tempdir, "out_multi_stream.dot")
            graph.debug_dump(dot_path)
2364

2365
    @unittest.skipIf(
2366
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
2367
    )
2368
    def test_graph_error(self):
2369
        # We need to run this test in a separate thread as the error we trigger
2370
        # puts the cuda context in a bad state
2371
        script = """
2372
import torch
2373

2374
g = torch.cuda.CUDAGraph()
2375
try:
2376
    g.capture_begin()
2377
except RuntimeError as e:
2378
    if "CUDA graphs must be captured on a non-default stream." in str(e):
2379
        exit(0)
2380
    else:
2381
        exit(1)
2382
exit(2)
2383
"""
2384
        try:
2385
            a = subprocess.check_output(
2386
                [sys.executable, "-c", script],
2387
                stderr=subprocess.STDOUT,
2388
                # On Windows, opening the subprocess with the default CWD makes `import torch`
2389
                # fail, so just set CWD to this script's directory
2390
                cwd=os.path.dirname(os.path.realpath(__file__)),
2391
            )
2392
        except subprocess.CalledProcessError as e:
2393
            if e.returncode == 1:
2394
                self.assertTrue(
2395
                    False,
2396
                    "Error raise by starting capture without a stream is not the expected one",
2397
                )
2398
            elif e.returncode == 2:
2399
                self.assertTrue(
2400
                    False,
2401
                    "Error raised by starting capture without a stream was not caught",
2402
                )
2403

2404
    @unittest.skipIf(
        (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11,
        "CUDA >= 11.0 required for graphs",
    )
    def test_graph_warn_if_has_zero_nodes(self):
        """Capturing an empty CUDA graph emits a "The CUDA Graph is empty" warning."""
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning regardless of globally installed filters, so
            # an "ignore"/"once" filter (or a previous empty capture in another
            # test) cannot hide the one we are checking for.
            warnings.simplefilter("always")
            g = torch.cuda.CUDAGraph()
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
                g.capture_begin()
                g.capture_end()
        self.assertTrue(
            any("The CUDA Graph is empty" in str(w.message) for w in caught)
        )
2418

2419
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @unittest.skipIf(
        IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
    )
    def test_graph_capture_oom(self):
        """An impossible allocation during graph capture raises an OOM RuntimeError.

        The expected message depends on which allocator backend is active.
        """
        if TEST_CUDAMALLOCASYNC:
            oom_regex = "would exceed allowed memory"
        else:
            oom_regex = "out of memory"
        huge_numel = 2**40
        with self.assertRaisesRegex(RuntimeError, oom_regex), torch.cuda.graph(
            torch.cuda.CUDAGraph()
        ):
            torch.zeros(huge_numel, device="cuda")
2432

2433
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @serialTest()
    def test_repeat_graph_capture_cublas_workspace_memory(self):
        """Repeatedly capturing a matmul must not keep growing device memory.

        Captures ``torch.mm`` into 100 fresh graphs and replays each one. It
        then checks that device-wide used memory (from ``mem_get_info``) grew
        by at most ``max_growth_gb``. Per the test name, this presumably guards
        against cuBLAS workspaces leaking per capture.
        """
        max_growth_gb = 0.1

        x, y, z = 1024, 512, 64
        a = torch.rand((x, y), device="cuda")
        b = torch.rand((y, z), device="cuda")

        # warmup
        torch.mm(a, b)

        free_bytes_before, total_bytes = torch.cuda.mem_get_info()
        used_gb_before = (total_bytes - free_bytes_before) / 1e9

        for _ in range(100):
            torch_graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(torch_graph):
                torch.mm(a, b)
            torch_graph.replay()

        free_bytes_after, _ = torch.cuda.mem_get_info()
        used_gb_after = (total_bytes - free_bytes_after) / 1e9

        self.assertLessEqual(used_gb_after, used_gb_before + max_growth_gb)
2458

2459
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_rng_functional(self):
        """Graphed RNG-consuming functional ops advance the RNG state like eager ops.

        For each op:
        - a graph capturing two invocations is interleaved with eager
          invocations and compared against a purely eager control of six;
        - the graph is then replayed under several new seeds to confirm it
          picks up each seed.
        """
        ops_with_kwargs = (
            (torch.nn.functional.dropout, {"p": 0.1}),
            (torch.nn.functional.rrelu, {"training": True}),
        )
        size = 10000

        def run(op, kwargs):
            # A single formatted message: RuntimeError("Failed on ", op) would
            # carry two args and render as a tuple.
            failure_msg = f"Failed on {getattr(op, '__name__', op)}"

            a = torch.randn((size,), device="cuda", dtype=torch.float)

            # Control: six eager invocations starting from seed 5.
            torch.cuda.manual_seed(5)
            eager_out = a
            for _ in range(6):
                eager_out = op(eager_out, **kwargs)

            graph_in = a.clone()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                torch.cuda.manual_seed(5)

                g = torch.cuda.CUDAGraph()
                torch.cuda.empty_cache()
                g.capture_begin()
                graph_out = graph_in
                for _ in range(2):
                    graph_out = op(graph_out, **kwargs)
                g.capture_end()
            torch.cuda.current_stream().wait_stream(stream)

            # Runs a graphed->eager->graphed sequence of RNG ops.
            # replay() plays 2 invocations of the op, so the sequence has 6
            # invocations total, matching Control.
            # replay() reads from graph_in and writes to graph_out.
            g.replay()
            out = op(graph_out, **kwargs)
            out = op(out, **kwargs)
            graph_in.copy_(out)
            g.replay()

            # If replay() updated RNG state correctly, graph_out
            # should now hold data equal to eager_out.
            try:
                self.assertEqual(eager_out, graph_out)
            except Exception as e:
                raise RuntimeError(failure_msg) from e

            # Do the same operations varying seeds
            seeds = [6, 128, 9999]

            for seed in seeds:
                torch.cuda.manual_seed(seed)
                graph_in.copy_(a)
                for _ in range(3):
                    g.replay()

                # If the random seed was not updated then the graph would
                # generate the same output as in previous check.
                try:
                    self.assertNotEqual(eager_out, graph_out)
                except Exception as e:
                    raise RuntimeError(failure_msg) from e

                # Now repeat the same operations in non-graphed mode.
                torch.cuda.manual_seed(seed)
                for _ in range(3):
                    eager_out.copy_(a)
                    eager_out = op(eager_out, **kwargs)
                    eager_out = op(eager_out, **kwargs)

                # In the end, graph_out and eager_out must be equal
                # as they went under the same set of operations.
                try:
                    self.assertEqual(eager_out, graph_out)
                except Exception as e:
                    raise RuntimeError(failure_msg) from e

            # We hold references to all tensors used across streams up til this sync,
            # so no need to call record_stream on those tensors.
            torch.cuda.synchronize()

        for op, kwargs in ops_with_kwargs:
            run(op, kwargs)
2546

2547
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_rng_distributions(self):
        """Captured RNG distribution ops reproduce eager results under any seed.

        Covers ``torch.*`` samplers/factories and in-place ``Tensor.*_``
        samplers. Each op is first run eagerly to build controls: a dummy draw
        advances the RNG, then two control draws follow. The two draws are then
        captured into a CUDA graph, and replays under fresh seeds must match
        freshly built eager controls.
        """
        size = 10000
        src = torch.rand((size,), device="cuda", dtype=torch.float)
        alloc = torch.empty((size,), device="cuda", dtype=torch.float)

        # Torch ops to test with sample args (tuple) and kwargs (dict)
        torch_with_args = (
            ("bernoulli", (src.clone(),), {}),
            # multinomial uses some uncapturable CUDA calls.
            # TODO: reenable multinomial tests if/when the implementation is capturable.
            # ("multinomial", (src.clone(), size, True), {}),
            # ("multinomial", (src.clone(), size // 2, False), {}),
            # TODO: reenable normal test, where std is a device
            # tensor, when graph test failures are fixed
            # ("normal", (src.clone() + 1, src.clone()), {}),
            ("normal", (src.clone() + 1, 1.0), {}),
            ("poisson", (src.clone(),), {}),
            ("rand", (size,), {"device": "cuda", "dtype": torch.float}),
            ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}),
            ("randn", (size,), {"device": "cuda", "dtype": torch.float}),
        )

        # Tensor methods to test with sample args (tuple)
        tensor_with_args = (
            ("bernoulli_", (src.clone(),)),
            ("cauchy_", ()),
            ("exponential_", ()),
            ("geometric_", (0.3,)),
            ("log_normal_", ()),
            ("normal_", ()),
            ("random_", ()),
            ("uniform_", ()),
        )

        def run(module, op, args, kwargs):
            def draw(target=None):
                # torch.* ops return a fresh tensor and ignore `target`;
                # Tensor.*_ methods fill `target` in place (without kwargs)
                # and hand it back.
                if module == "torch":
                    return getattr(torch, op)(*args, **kwargs)
                getattr(target, op)(*args)
                return target

            torch.cuda.manual_seed(5)

            # Eager controls: a dummy draw advances the state a bit first.
            if module == "torch":
                targets = (None, None, None)
            else:
                targets = (alloc.clone(), alloc.clone(), alloc.clone())
            dummy, control1, control2 = [draw(t) for t in targets]

            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                torch.cuda.manual_seed(5)

                g = torch.cuda.CUDAGraph()
                torch.cuda.empty_cache()
                if module == "torch":
                    t1 = t2 = None
                else:
                    # In-place targets must be allocated before capture begins.
                    t1, t2 = alloc.clone(), alloc.clone()
                g.capture_begin()
                t1, t2 = draw(t1), draw(t2)
                g.capture_end()
            torch.cuda.current_stream().wait_stream(stream)

            if not TEST_CUDAMALLOCASYNC:
                # Checks that capture did not actually run the ops (values not
                # populated yet). This is only possible with the native
                # allocator, whose captured addresses are already backed by
                # cudaMalloc'ed memory. With cudaMallocAsync, CUDA won't even
                # consider the captured addresses allocated until replay(), and
                # touching them before replay() causes IMAs.
                try:
                    self.assertNotEqual(control1, t1)
                    self.assertNotEqual(control2, t2)
                except Exception as e:
                    raise RuntimeError(f"Failed on {module}.{op}") from e

            # Reseed to check that the graph picks up each new seed.
            for seed in [6, 314, 271]:
                torch.cuda.manual_seed(seed)
                # Same dummy-then-controls prelude as above, so replay() must
                # account for the dummy draw's state increment.
                dummy, control1, control2 = [
                    draw(t) for t in (dummy, control1, control2)
                ]

                torch.cuda.manual_seed(seed)
                dummy = draw(dummy)

                # see the TEST_CUDAMALLOCASYNC note above
                if not TEST_CUDAMALLOCASYNC:
                    t1.copy_(alloc)
                    t2.copy_(alloc)

                # Replay the captured RNG ops that fill t1 and t2.
                g.replay()

                try:
                    self.assertEqual(control1, t1)
                    self.assertEqual(control2, t2)
                except Exception as e:
                    raise RuntimeError(f"Failed on {module}.{op}") from e

            # We hold references to all tensors used across streams up til this sync,
            # so no need to call record_stream on those tensors.
            torch.cuda.synchronize()

        for name, op_args, op_kwargs in torch_with_args:
            run("torch", name, op_args, op_kwargs)

        for name, op_args in tensor_with_args:
            # None of the Tensor methods take kwargs, so pass an empty dict.
            run("Tensor", name, op_args, {})
2679

2680
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_two_successive(self):
        """Two graphs captured back to back replay correctly, with or without pool sharing.

        Covers three memory-sharing modes:
        - separate pools;
        - ``g1`` capturing into ``g0.pool()``;
        - both graphs capturing into one ``graph_pool_handle()``.

        With the native allocator, each sharing mode must reserve exactly
        ``kSmallBuffer`` bytes less than the non-sharing run.
        """
        torch.cuda.empty_cache()

        size = 1000
        kSmallBuffer = 2097152

        def func_with_temps(t, val):
            x = t.clone() + val
            y = t.clone() + val
            return x + y

        s = torch.cuda.Stream()

        # Reserved bytes measured in the first ("Don't share") iteration; the
        # sharing iterations below compare against it.
        reserved_no_sharing = None

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()

            a = torch.ones((size,), device="cuda")

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = a.clone()
                for _ in range(5):
                    b = func_with_temps(b, 1)
                g0.capture_end()

                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
                g1.capture_begin(*g1_args)
                for _ in range(5):
                    b = func_with_temps(b, 1)
                g1.capture_end()
            torch.cuda.current_stream().wait_stream(s)

            # mixes unrelated eager ops with replays
            c = a.clone()
            for _ in range(2):
                c = func_with_temps(c, 3)
            g0.replay()
            for _ in range(2):
                c = func_with_temps(c, 3)
            g1.replay()
            for _ in range(2):
                c = func_with_temps(c, 3)

            # func_with_temps(t, v) == 2 * t + 2 * v, so starting from 1:
            # b after 10 steps with v=1: (1 + 2) * 2**10 - 2 = 3070
            # c after 6 steps with v=3:  (1 + 6) * 2**6 - 6 = 442
            self.assertEqual(b.sum().item(), size * 3070)
            self.assertEqual(c.sum().item(), size * 442)

            if not TEST_CUDAMALLOCASYNC:
                # These stat checks are specific to the native allocator.
                if share_mem != "Don't share":
                    self.assertEqual(
                        reserved_no_sharing
                        - torch.cuda.memory_stats()["reserved_bytes.all.current"],
                        kSmallBuffer,
                    )
                else:
                    reserved_no_sharing = torch.cuda.memory_stats()[
                        "reserved_bytes.all.current"
                    ]

            del a, b, c, g0, g1
            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
2753

2754
    @unittest.skipIf(
        (not TEST_CUDA_GRAPH)
        or IS_WINDOWS  # appears to still be broken on Windows as of 11.4+
        or (
            torch.version.cuda
            and int(torch.version.cuda.split(".")[0]) == 11
            and int(torch.version.cuda.split(".")[1]) < 4
        ),
        "Graph bindings disallow concurrent replay for CUDA < 11.4, see "
        + "https://github.com/pytorch/pytorch/pull/57556",
    )
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_concurrent_replay(self):
        """Replay two captured graphs concurrently and check for cross-talk.

        Each iteration captures ``g0`` and ``g1`` either without sharing a
        mempool ("Don't share") or sharing one (via ``g0.pool()`` or a common
        ``graph_pool_handle()``), then replays them on two side streams arranged
        so their kernels overlap on the device. With the native allocator and a
        shared pool the concurrent replays are expected to corrupt each other;
        otherwise (no sharing, or cudaMallocAsync) results must be exact.
        """
        torch.cuda.empty_cache()

        size = 1000000  # largeish to help expose race conditions

        def func_with_temps(t, val):
            # Allocates two temporaries per call and returns 2 * t + 2 * val.
            x = t.clone() + val
            y = t.clone() + val
            return x + y

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()

            s0 = torch.cuda.Stream()
            s1 = torch.cuda.Stream()

            a = torch.ones((size,), device="cuda")

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = a.clone()
                for _ in range(5):
                    b = func_with_temps(b, 1)
                g0.capture_end()

                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
                g1.capture_begin(*g1_args)
                c = a.clone()
                for _ in range(5):
                    c = func_with_temps(c, 2)
                g1.capture_end()

            # To reproduce data corruption, g0's and g1's kernels must run concurrently.
            # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead.
            # The following pattern helps align device-side execution of g0's and g1's
            # kernels: s0 sleeps on the device and s1 waits on s0, so both replays
            # become runnable only once the sleep finishes.
            torch.cuda.synchronize()
            with torch.cuda.stream(s0):
                torch.cuda._sleep(1000000)
                s1.wait_stream(s0)
                g0.replay()
            with torch.cuda.stream(s1):
                g1.replay()
            torch.cuda.current_stream().wait_stream(s0)
            torch.cuda.current_stream().wait_stream(s1)

            # Uncorrupted values: 1 -> 4 -> 10 -> 22 -> 46 -> 94 for b (val=1),
            # and 1 -> 6 -> 16 -> 36 -> 76 -> 156 for c (val=2).
            if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"):
                # If we used the native allocator and shared mempools,
                # we expect the concurrent replays corrupted each other.
                self.assertNotEqual(b.sum().item(), size * 94)
                self.assertNotEqual(c.sum().item(), size * 156)
            else:
                # If we EITHER
                #   - used the native allocator without sharing mempools, OR
                #   - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe
                # we don't expect memory corruption.
                self.assertEqual(b.sum().item(), size * 94)
                self.assertEqual(c.sum().item(), size * 156)

            del a, b, c, g0, g1
            # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_three_successive(self):
        """Check replay-order hazards across three successively captured graphs.

        ``g0`` produces ``c`` and ``d``; ``g1`` reads ``c`` (which is deleted
        during ``g1``'s capture); ``g2`` reads ``d``. Replaying in capture order
        is always valid. Replaying g0, g2, g1 is only valid when the graphs do
        not share a mempool under the native allocator.
        """
        torch.cuda.empty_cache()

        size = 1000

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            a = torch.ones((size,), device="cuda")

            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()
            g2 = torch.cuda.CUDAGraph()

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = a.clone()  # 1
                c = b + 1  # 2
                d = b + 2  # 3
                g0.capture_end()

                args = (g0.pool(),) if share_mem == "via pool()" else g0_args

                g1.capture_begin(*args)
                e = c + 3  # 5
                # Dropping c here frees its block while g1's capture is still open.
                del c
                g1.capture_end()

                g2.capture_begin(*args)
                f = d + 4  # 7
                g2.capture_end()
            torch.cuda.current_stream().wait_stream(s)

            # Tests that replaying in capture order is valid
            g0.replay()
            g1.replay()
            g2.replay()

            self.assertEqual(e.sum().item(), size * 5)
            self.assertEqual(f.sum().item(), size * 7)

            # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool
            g0.replay()
            g2.replay()
            g1.replay()

            expect_corruption = (not TEST_CUDAMALLOCASYNC) and (
                share_mem != "Don't share"
            )
            # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f.
            # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3".
            self.assertEqual(
                e.sum().item(), size * (7 + 3) if expect_corruption else size * 5
            )
            self.assertEqual(f.sum().item(), size * 7)

            del a, b, d, e, f, g0, g1, g2
            # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

    @unittest.skipIf(
        (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC,
        "CUDA >= 11.0 or ROCM >= 5.3 required for graphs",
    )
    def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
        """Check allocator stat deltas caused by a capture and by deleting its graph.

        For each case, captures ``b = a.clone()`` followed by five
        ``b.clone() + 1`` steps. It then compares native-allocator stats
        (segments, reserved bytes, active blocks, active bytes) at three points:
        before capture, after capture (checked both before and after
        ``empty_cache``), and after the graph is deleted. It also checks that the
        graph output ``b`` remains usable after the graph is destroyed.
        """
        kSmallSize = 1048576
        kSmallBuffer = 2097152
        kLargeBuffer = 20971520
        kMinLargeAlloc = 10485760
        kRoundLarge = 2097152

        elem = 4

        # this was annoying to write but stresses the expectations pretty rigorously
        # Each case is (numel, expected "segment" delta, expected "reserved_bytes"
        # delta after capture, expected "reserved_bytes" delta after the graph is
        # deleted, pool name).
        cases = (
            (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"),
            (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"),
            ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"),
            (
                (kMinLargeAlloc - 512) // elem,
                2,
                2 * kLargeBuffer,
                kLargeBuffer,
                "large_pool",
            ),
            (
                (kMinLargeAlloc + 512) // elem,
                3,
                3
                * (
                    kRoundLarge
                    * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge)
                ),
                kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge),
                "large_pool",
            ),
        )

        stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.")

        gc.collect()
        torch.cuda.empty_cache()

        s = torch.cuda.Stream()

        for (
            numel,
            delta_cudaMallocs,
            delta_cudaMalloc_bytes,
            delta_cudaMalloc_bytes_post_del_g,
            pool_string,
        ) in cases:
            if pool_string == "small_pool":
                delta_active_blocks = 3  # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders
                # + 1024 bytes in total for CUDAGraph's rng seed and offset holders
                # (presumably one 512-byte minimum block each).
                delta_active_bytes = numel * elem + 1024
            else:
                delta_active_blocks = 1  # We only check the large pool, which isn't affected by rng offset holder
                delta_active_bytes = numel * elem

            g = torch.cuda.CUDAGraph()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                # Allocation stat estimates assume input is created on the same stream as capture_begin()
                # (in other words, the same stream silo as the rng offset holder, which is not allocated from the
                # capture's private pool).
                a = torch.ones((numel,), device="cuda")

                precapture_stats = torch.cuda.memory_stats()

                g.capture_begin()
                b = a.clone()
                for _ in range(5):
                    b = b.clone() + 1
                g.capture_end()
            torch.cuda.current_stream().wait_stream(s)

            gc.collect()

            postcapture_stats = torch.cuda.memory_stats()

            expecteds = (
                delta_cudaMallocs,
                delta_cudaMalloc_bytes,
                delta_active_blocks,
                delta_active_bytes,
            )
            # Double checks replay and stats before and after a call to empty_cache
            for i in range(2):
                for stat, expected in zip(stats_to_check, expecteds):
                    stat = stat + pool_string + ".current"
                    current = postcapture_stats[stat] - precapture_stats[stat]

                    # There will only ever be one expandable segment in each of the small and large pools. The way the
                    # bookkeeping is done in the allocator means that we never increment the number of segments.
                    if self.expandable_segments and "segment" in stat:
                        expected = 0
                    # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
                    # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
                    # smaller than the page size
                    if (
                        self.expandable_segments
                        and "reserved" in stat
                        and numel in (cases[3][0], cases[4][0])
                    ):
                        expected = 2 * kLargeBuffer

                    self.assertEqual(
                        current,
                        expected,
                        "Pre to post capture delta of "
                        + stat
                        + f" = {current}, expected = {expected}, numel = {numel}",
                    )

                # a is all ones and b = a + 5 after five "+ 1" steps.
                g.replay()
                self.assertEqual(b.sum().item(), 6 * numel)
                if i == 0:
                    torch.cuda.empty_cache()

            del g
            gc.collect()
            torch.cuda.empty_cache()
            postdel_stats = torch.cuda.memory_stats()

            # Uses graph result b after graph has been deleted
            self.assertEqual(b.sum().item(), 6 * numel)

            # b should be the only live reference remaining from the graph's private pool
            expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem)
            for stat, expected in zip(stats_to_check, expecteds):
                stat = stat + pool_string + ".current"
                current = postdel_stats[stat] - precapture_stats[stat]

                # There will only ever be one expandable segment in each of the small and large pools. The way the
                # bookkeeping is done in the allocator means that we never increment the number of segments.
                if self.expandable_segments and "segment" in stat:
                    expected = 0
                # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an
                # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is
                # smaller than the page size
                if (
                    self.expandable_segments
                    and "reserved" in stat
                    and numel == cases[3][0]
                ):
                    expected = 2 * kLargeBuffer
                if (
                    self.expandable_segments
                    and "reserved" in stat
                    and numel == cases[4][0]
                ):
                    expected = kLargeBuffer

                self.assertEqual(
                    current,
                    expected,
                    "Pre capture to post graph delete delta of "
                    + stat
                    + f" = {current}, expected = {expected}, numel = {numel}",
                )

            # del a, b before the next case is essential, otherwise overwriting a and b in the next case
            # can throw off its allocation/deallocation counts.
            del a, b
            # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_record_stream(self):
        """Check that capture defers reclaiming allocations used across streams.

        See "Q. Why skip process_events if a capture might be underway?" in
        c10/cuda/CUDACachingAllocator.cpp.
        """
        torch.cuda.empty_cache()

        potential_problem = torch.zeros((3,), device="cuda")
        a = torch.zeros((3,), device="cuda")
        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()
        g = torch.cuda.CUDAGraph()

        torch.cuda.synchronize()
        with torch.cuda.stream(s0):
            # Record potential_problem on s0 and queue it behind a long device sleep,
            # then drop the Python reference, so its end-of-life event is still
            # pending when the capture below starts.
            potential_problem.record_stream(s0)
            torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES)
            potential_problem.fill_(1.0)
        del potential_problem

        with torch.cuda.stream(s1):
            g.capture_begin()
            # potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc
            # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life
            # event, which will cause the capture to error.
            b = a.clone()

            # Let's also see what happens if we record_stream on a tensor during capture.
            s2.wait_stream(s1)
            with torch.cuda.stream(s2):
                b.fill_(1.0)
                b.record_stream(s2)  # dummy record_stream
                del b
            s1.wait_stream(s2)
            g.capture_end()
        torch.cuda.synchronize()

        # Dummy allocation triggers process_events; hopefully it successfully
        # processes b's end-of-life event.
        c = torch.zeros((3,), device="cuda")  # noqa: F841

    @skipIfRocm
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize
    # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior
    # as a memory leak unless we skip the leak check.
    @skipCUDAMemoryLeakCheckIf(True)
    @serialTest()
    def test_graph_cudnn_dropout(self):
        """Interleave eager and captured cuDNN LSTM-with-dropout calls.

        Tests the interaction of cuda graph capture with DropoutState's syncs in
        ATen/native/cudnn/RNN.cpp. In particular, if the user runs a sequence of
        captured and noncaptured cudnn rnns, DropoutState should avoid syncing
        noncapturing streams with captured events or vice versa. The test passes
        if none of the calls below raises.
        """
        torch.cuda.empty_cache()

        model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
        x = torch.ones(100, 192, 512, device="cuda")

        # Eager call before capture.
        y = model(x)

        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            g.capture_begin()
            y = model(x)
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()

        # Eager call after capture and replay.
        y = model(x)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize(
        "with_amp,cache_enabled,allow_unused_input",
        [
            subtest((False, False, True), decorators=[skipIfRocm]),
            subtest((True, False, True), decorators=[skipIfRocm]),
            subtest((True, True, True), decorators=[unittest.expectedFailure]),
            subtest((False, False, False), decorators=[unittest.expectedFailure]),
        ],
        name_fn=lambda x, y, z: "{}{}{}".format(
            {True: "with_amp", False: "without_amp"}[x],
            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
        ),
    )
    @serialTest()
    def test_graph_make_graphed_callables(
        self, with_amp, cache_enabled, allow_unused_input
    ):
        """Train a partly graphed model and compare it against an eager control.

        Three sections of a ``Sequential`` model plus two functional ops (relu and
        mse_loss) are graphed via ``torch.cuda.make_graphed_callables``. Both
        models are then trained for ten SGD steps from identical RNG states, and
        the resulting parameters must match. Eval mode must still run ungraphed.
        """
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)

        N, D_in, H, D_out = 640, 4096, 2048, 1024

        class MLP1(torch.nn.Module):
            def __init__(self, D_in: int, H: int, D_out: int):
                super().__init__()
                self.net_1 = torch.nn.Sequential(
                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
                ).cuda()
                self.net_2 = torch.nn.Sequential(
                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
                ).cuda()

            def forward(self, input_dict: dict):
                # Only "x" is consumed; other keys (e.g. "unused_input") are ignored.
                x = input_dict["x"]
                return self.net_2(self.net_1(x))

        class MLP2(torch.nn.Module):
            def __init__(self, D_in: int, H: int, D_out: int):
                super().__init__()
                self.net_1 = torch.nn.Sequential(
                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
                ).cuda()
                self.net_2 = torch.nn.Sequential(
                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
                ).cuda()

            def forward(self, x):
                return self.net_2(self.net_1(x))

        class ParameterlessModule(torch.nn.Module):
            def forward(self, x):
                # Gathers row i of x into row i of the output (identity gather).
                idx = (
                    torch.arange(x.size(0), device=x.device)
                    .view(-1, 1)
                    .repeat(1, x.size(1))
                )
                return {"output": torch.gather(x, 0, idx)}

        models = []
        for _ in range(2):
            model_section1 = MLP1(D_in, H, H).cuda()
            model_section2 = MLP2(H, H, D_out).cuda()
            model_section3 = ParameterlessModule().cuda()
            models.append(
                torch.nn.Sequential(model_section1, model_section2, model_section3)
            )

        model_graphed = models[0]
        model_control = models[1]

        model_graphed.load_state_dict(model_control.state_dict())

        opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1)
        opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1)

        x = torch.randn(N, D_in, device="cuda")
        h = torch.randn(N, H, device="cuda", requires_grad=True)
        h2 = torch.randn(N, D_out, device="cuda", requires_grad=True)
        unused_input = torch.randn(N, H, device="cuda", requires_grad=True)
        y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True)
        y = torch.randn(N, D_out, device="cuda")

        loss_fn_control = torch.nn.functional.mse_loss
        relu_control = torch.nn.functional.relu

        # This is a good stress test. It graphs five callables: three Modules
        # (the sections of the Sequential) and two python functions.
        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
            (
                model_graphed[0],
                model_graphed[1],
                model_graphed[2],
                relu_graphed,
                loss_fn_graphed,
            ) = torch.cuda.make_graphed_callables(
                (
                    model_graphed[0],
                    model_graphed[1],
                    model_graphed[2],
                    relu_control,
                    loss_fn_control,
                ),
                (
                    ({"x": x, "unused_input": unused_input},),
                    (h,),
                    (h2,),
                    (y_pred,),
                    (y_pred, y),
                ),
                allow_unused_input=allow_unused_input,
            )

        real_inputs = [torch.rand_like(x) for _ in range(10)]
        real_targets = [torch.rand_like(y) for _ in range(10)]

        for m, opt, relu, loss_fn in zip(
            (model_graphed, model_control),
            (opt_graphed, opt_control),
            (relu_graphed, relu_control),
            (loss_fn_graphed, loss_fn_control),
        ):
            # Resets RNG states before iterations for graphed and ungraphed models,
            # so dropout math should be bitwise identical for both.
            torch.manual_seed(5)
            torch.cuda.manual_seed(5)
            for data, target in zip(real_inputs, real_targets):
                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
                    y_pred = m({"x": data, "unused_input": unused_input})["output"]
                    y_pred = relu(y_pred)
                    loss = loss_fn(y_pred, target)
                    loss.backward()
                opt.step()

        for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
            self.assertEqual(p, pc)

        # We graphed the models in training mode. Eval should still run ungraphed.
        model_graphed.eval()
        model_control.eval()
        self.assertEqual(
            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize(
        "with_amp,cache_enabled,allow_unused_input",
        [
            subtest((False, False, True), decorators=[skipIfRocm]),
            subtest((True, False, True), decorators=[skipIfRocm]),
            subtest((True, True, True), decorators=[unittest.expectedFailure]),
            subtest((False, False, False), decorators=[skipIfRocm]),
        ],
        name_fn=lambda x, y, z: "{}{}{}".format(
            {True: "with_amp", False: "without_amp"}[x],
            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
        ),
    )
    @serialTest()
    def test_graph_make_graphed_callables_parameterless_nograd_module(
        self, with_amp, cache_enabled, allow_unused_input
    ):
        """Graph a parameter-free Module whose sample inputs do not require grad.

        Calls the graphed module and an eager control on the same inputs, then
        checks that both agree in eval mode, where the graphed one runs ungraphed.
        """
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)

        N, D_in, H, D_out = 640, 4096, 2048, 1024

        class ParameterlessModule(torch.nn.Module):
            def forward(self, input_dict: dict):
                # Only "x" is consumed; gathers row i of x into row i (identity gather).
                x = input_dict["x"]
                idx = (
                    torch.arange(x.size(0), device=x.device)
                    .view(-1, 1)
                    .repeat(1, x.size(1))
                )
                return {"output": torch.gather(x, 0, idx)}

        models = []
        for _ in range(2):
            model_section1 = ParameterlessModule().cuda()
            models.append(torch.nn.Sequential(model_section1))

        model_graphed = models[0]
        model_control = models[1]

        model_graphed.load_state_dict(model_control.state_dict())

        x = torch.randn(N, D_in, device="cuda", requires_grad=False)
        unused_input = torch.randn(N, H, device="cuda", requires_grad=False)
        y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False)
        y = torch.randn(N, D_in, device="cuda")

        # Graphs a single parameterless Module whose sample inputs do not require grad.
        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
            model_graphed[0] = torch.cuda.make_graphed_callables(
                model_graphed[0],
                ({"x": x, "unused_input": unused_input},),
                allow_unused_input=allow_unused_input,
            )

        real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)]
        real_targets = [torch.rand_like(y) for _ in range(10)]

        for m in (model_graphed, model_control):
            # Resets RNG states before iterations for graphed and ungraphed models,
            # so both runs start from the same RNG state.
            torch.manual_seed(5)
            torch.cuda.manual_seed(5)
            for data, target in zip(real_inputs, real_targets):
                with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
                    out = m({"x": data, "unused_input": unused_input})["output"]

        # We graphed the models in training mode. Eval should still run ungraphed.
        model_graphed.eval()
        model_control.eval()
        self.assertEqual(
            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_make_graphed_callables_same_pool(self):
        """Graph several independent models into one shared mempool.

        Each graphed copy must match its eager original in outputs, losses,
        parameters and gradients, while owning distinct parameter and gradient
        storage.
        """
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)
        num_models = 3
        models = [
            torch.nn.Sequential(
                torch.nn.Linear(32, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 128),
            ).cuda()
            for _ in range(num_models)
        ]
        # we will reuse the same pool for all graph captures
        mempool = torch.cuda.graph_pool_handle()
        graphed_models = []
        for model in models:
            x = torch.randn([64, 32], device="cuda")
            # Graph a deep copy so the eager model keeps its own parameters.
            graphed_model = deepcopy(model)
            graphed_model = torch.cuda.make_graphed_callables(
                graphed_model, (x,), pool=mempool
            )
            graphed_models.append(graphed_model)

        for model, graphed_model in zip(models, graphed_models):
            x = torch.randn([64, 32], device="cuda")
            y = model(x)
            yg = graphed_model(x)
            loss = y.norm()
            loss_graphed = yg.norm()
            loss.backward()
            loss_graphed.backward()

            self.assertEqual(y, yg)
            self.assertEqual(loss, loss_graphed)
            for p, pg in zip(model.parameters(), graphed_model.parameters()):
                self.assertEqual(p, pg)
                self.assertEqual(p.grad, pg.grad)
                self.assertNotEqual(p.data_ptr(), pg.data_ptr())
                self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr())

    def _test_graphed_optimizer(
        self, steps_warmup, steps_train, optimizer_ctor, kwargs
    ):
        """Compare a ``capturable=True`` optimizer against a ``capturable=False`` control.

        Runs twice: once capturing a single ``opt.step()`` in a CUDA graph and
        replaying it, and once stepping eagerly with ``capturable=True``. Both the
        control and the capturable optimizer receive the same per-step gradients
        for ``steps_warmup + steps_train`` steps. The final parameters must match.

        Args:
            steps_warmup: eager steps taken before the (optional) capture.
            steps_train: steps taken after warmup (graph replays when graphed).
            optimizer_ctor: optimizer class or factory accepting ``capturable=``
                plus ``**kwargs``.
            kwargs: extra keyword arguments forwarded to ``optimizer_ctor``.
        """
        for actually_do_graphs in (True, False):
            # Two square matrices plus one 0-dim tensor.
            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [
                torch.randn((), device="cuda")
            ]
            params_control = [p.clone().requires_grad_() for p in params]
            params_graphed = [p.clone().requires_grad_() for p in params]

            grads = [
                [torch.randn_like(p) for p in params]
                for _ in range(steps_warmup + steps_train)
            ]

            # Control (capturable=False)

            opt = optimizer_ctor(params_control, capturable=False, **kwargs)

            for i in range(steps_warmup + steps_train):
                for j, p in enumerate(params_control):
                    p.grad = grads[i][j]
                opt.step()

            # capturable=True

            opt = optimizer_ctor(params_graphed, capturable=True, **kwargs)

            for i in range(steps_warmup):
                for j, p in enumerate(params_graphed):
                    p.grad = grads[i][j]
                opt.step()

            if actually_do_graphs:
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    opt.step()

            for i in range(steps_train):
                if actually_do_graphs:
                    # The captured step reads the grad tensors that were bound at
                    # capture time, so new grads are copied in place instead of
                    # rebinding p.grad.
                    for j, p in enumerate(params_graphed):
                        p.grad.copy_(grads[i + steps_warmup][j])
                    g.replay()
                else:
                    # Passing capturable=True to the constructor and running without graphs should still be
                    # numerically correct, even if it's not ideal for performance.
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i + steps_warmup][j]
                    opt.step()

            for p_control, p_graphed in zip(params_control, params_graphed):
                self.assertEqual(p_control, p_graphed)

    @unittest.skipIf(
3478
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
3479
    )
3480
    def test_graph_optims_with_explicitly_capturable_param_groups(self):
3481
        # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__
3482
        n_warmup, n_replay = 3, 2
3483
        for optimizer, second_param_group_capturable in product(
3484
            (
3485
                torch.optim.Adam,
3486
                torch.optim.AdamW,
3487
                torch.optim.ASGD,
3488
                torch.optim.Adamax,
3489
                torch.optim.NAdam,
3490
                torch.optim.RAdam,
3491
                torch.optim.Adadelta,
3492
                torch.optim.RMSprop,
3493
                torch.optim.Rprop,
3494
            ),
3495
            (True, False),
3496
        ):
3497
            ref_p1, param1 = (
3498
                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
3499
            )
3500
            ref_p2, param2 = (
3501
                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
3502
            )
3503
            grads1, grads2 = (
3504
                [torch.randn_like(param1) for _ in range(n_warmup + n_replay)]
3505
                for _ in range(2)
3506
            )
3507
            ref_grads1, ref_grads2 = (
3508
                [t.clone() for t in tensors] for tensors in (grads1, grads2)
3509
            )
3510
            params = [
3511
                {"params": [param1], "capturable": True},
3512
                {"params": [param2], "capturable": second_param_group_capturable},
3513
            ]
3514
            opt = optimizer(params)
3515
            opt_ = optimizer(
3516
                [
3517
                    {"params": [ref_p1], "capturable": False},
3518
                    {"params": [ref_p2], "capturable": False},
3519
                ]
3520
            )
3521

3522
            for i in range(n_warmup + n_replay):
3523
                ref_p1.grad = ref_grads1[i]
3524
                ref_p2.grad = ref_grads2[i]
3525
                opt_.step()
3526

3527
            for i in range(n_warmup):
3528
                param1.grad = grads1[i]
3529
                param2.grad = grads2[i]
3530
                opt.step()
3531

3532
            g = torch.cuda.CUDAGraph()
3533
            if not second_param_group_capturable:
3534
                with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"):
3535
                    with torch.cuda.graph(g):
3536
                        opt.step()
3537
            else:
3538
                with torch.cuda.graph(g):
3539
                    opt.step()
3540

3541
                for i in range(n_replay):
3542
                    param1.grad.copy_(grads1[n_warmup + i])
3543
                    param2.grad.copy_(grads2[n_warmup + i])
3544
                    g.replay()
3545
                self.assertEqual(ref_p1, param1)
3546
                self.assertEqual(ref_p2, param2)
3547

3548
    @unittest.skipIf(
3549
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
3550
    )
3551
    def test_cuda_graph_error_options(self):
3552
        def fn():
3553
            x = torch.zeros([2000], device="cuda")
3554
            y = x + x + x
3555
            return y
3556

3557
        mem = None
3558

3559
        def raw_malloc():
3560
            global mem
3561
            mem = None
3562
            stream = torch.cuda.Stream()
3563
            try:
3564
                with torch.cuda.stream(stream):
3565
                    mem = torch.cuda.caching_allocator_alloc(1024)
3566
            except BaseException:
3567
                if mem is None:
3568
                    return
3569
            try:
3570
                torch.cuda.caching_allocator_delete(mem)
3571
                mem = None
3572
                return None
3573
            except BaseException:
3574
                pass
3575

3576
        def throws_on_cuda_event(capture_error_mode):
3577
            graph = torch.cuda.CUDAGraph()
3578
            torch.cuda.synchronize()
3579
            stream = torch.cuda.Stream()
3580
            stream.wait_stream(torch.cuda.current_stream())
3581
            with torch.cuda.stream(stream):
3582
                fn()
3583
            stream.synchronize()
3584
            torch.cuda.current_stream().wait_stream(stream)
3585
            torch.cuda.synchronize()
3586
            try:
3587
                with torch.cuda.graph(
3588
                    graph, stream=stream, capture_error_mode=capture_error_mode
3589
                ):
3590
                    out = fn()
3591
                    thread = threading.Thread(target=raw_malloc)
3592
                    thread.start()
3593
                    thread.join()
3594
            except Exception:
3595
                if mem is not None:
3596
                    torch.cuda.caching_allocator_delete(mem)
3597
                return True
3598

3599
            return False
3600

3601
        self.assertFalse(throws_on_cuda_event("thread_local"))
3602
        self.assertFalse(throws_on_cuda_event("relaxed"))
3603

3604
        # Exception would Corrupt Process and make other tests fail
3605
        # self.assertTrue(throws_on_cuda_event("global"))
3606

3607
    @unittest.skipIf(
3608
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
3609
    )
3610
    def test_cuda_graph_allocator_propagates_stream(self):
3611
        segments = torch.cuda.memory_snapshot()
3612
        existing_pools = {s["segment_pool_id"] for s in segments}
3613
        x = torch.randn(10240000, device="cuda")
3614
        y = torch.rand_like(x)
3615
        g = torch.cuda.CUDAGraph()
3616
        s0 = torch.cuda.Stream()
3617
        s1 = torch.cuda.Stream()
3618
        s0.wait_stream(torch.cuda.current_stream())
3619
        with torch.cuda.stream(s0):
3620
            g.capture_begin()
3621
            z = x + y
3622
        with torch.cuda.stream(s1):
3623
            s1.wait_stream(s0)
3624
            w = z + y
3625
        s0.wait_stream(s1)
3626
        with torch.cuda.stream(s0):
3627
            g.capture_end()
3628
        segments = torch.cuda.memory_snapshot()
3629
        x = [
3630
            s["segment_pool_id"]
3631
            for s in segments
3632
            if s["segment_pool_id"] not in existing_pools
3633
        ]
3634
        self.assertEqual(len(x), 2)
3635
        self.assertEqual(x[0], x[1])
3636

3637
    def test_batch_norm_gather_stats(self):
3638
        input = torch.randn(1, 3, 3, 3, device="cuda")
3639
        mean, invstd = torch.batch_norm_gather_stats(
3640
            input,
3641
            mean=torch.ones(2, 3, device="cuda"),
3642
            invstd=torch.ones(2, 3, device="cuda"),
3643
            running_mean=None,
3644
            running_var=None,
3645
            momentum=0.1,
3646
            eps=1e-5,
3647
            count=2,
3648
        )
3649
        self.assertEqual(mean, torch.ones(3, device="cuda"))
3650
        self.assertEqual(invstd, torch.ones(3, device="cuda"))
3651

3652
    def test_matmul_memory_use(self):
3653
        def get_max_used():
3654
            torch.cuda.synchronize()
3655
            val = torch.cuda.max_memory_allocated()
3656
            torch.cuda.reset_peak_memory_stats()
3657
            return val
3658

3659
        a = torch.rand(1, 32, 32, device="cuda")
3660
        b = torch.rand(24, 32, 1, device="cuda")
3661

3662
        get_max_used()
3663

3664
        torch.matmul(a, b)
3665

3666
        matmul_mem = get_max_used()
3667

3668
        a = a.expand(24, 32, 32)
3669
        torch.matmul(a, b)
3670

3671
        matmul_expand_mem = get_max_used()
3672

3673
        torch.bmm(a, b)
3674

3675
        bmm_mem = get_max_used()
3676

3677
        self.assertEqual(matmul_expand_mem, matmul_mem)
3678
        self.assertEqual(bmm_mem, matmul_mem)
3679

3680
    @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test")
3681
    def test_rocm_backward_pass_guard(self):
3682
        # The test exercises a ROCm-specific feature.
3683

3684
        class MyFunction(torch.autograd.Function):
3685
            @staticmethod
3686
            def forward(ctx, tensor, constant):
3687
                self.assertFalse(torch._C._rocm_is_backward_pass())
3688
                ctx.constant = constant
3689
                return tensor * constant
3690

3691
            @staticmethod
3692
            def backward(ctx, grad_output):
3693
                self.assertTrue(torch._C._rocm_is_backward_pass())
3694
                return grad_output * ctx.constant, None
3695

3696
        class MyModule(torch.nn.Module):
3697
            def __init__(self) -> None:
3698
                super().__init__()
3699
                self.a = torch.nn.Parameter(torch.randn(()))
3700

3701
            def forward(self, x):
3702
                return MyFunction.apply(x, self.a)
3703

3704
        model = MyModule()
3705
        criterion = torch.nn.MSELoss(reduction="sum")
3706
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
3707

3708
        x = torch.randn(5, 5)
3709
        result = model(x)
3710
        loss = criterion(result, x)
3711
        optimizer.zero_grad()
3712
        loss.backward()
3713
        optimizer.step()
3714

3715
    def test_matmul_device_mismatch(self):
3716
        cpu = torch.rand((10, 10))
3717
        cuda = cpu.cuda()
3718
        with self.assertRaisesRegex(
3719
            RuntimeError, "Expected all tensors to be on the same device"
3720
        ):
3721
            cpu @ cuda
3722
        with self.assertRaisesRegex(
3723
            RuntimeError, "Expected all tensors to be on the same device"
3724
        ):
3725
            cuda @ cpu
3726

3727
        for s, m1, m2 in product((cpu, cuda), repeat=3):
3728
            if s.device == m1.device == m2.device:
3729
                torch.addmm(s, m1, m2)
3730
            else:
3731
                with self.assertRaisesRegex(
3732
                    RuntimeError, "Expected all tensors to be on the same device"
3733
                ):
3734
                    torch.addmm(s, m1, m2)
3735

3736
    @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient")
3737
    def test_lazy_init(self):
3738
        """Validate that no CUDA calls are made during `import torch` call"""
3739

3740
        def check_output(script: str) -> str:
3741
            return (
3742
                subprocess.check_output([sys.executable, "-c", script])
3743
                .decode("ascii")
3744
                .strip()
3745
            )
3746

3747
        VISIBLE_DEVICES = (
3748
            "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES"
3749
        )
3750
        test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())"
3751
        rc = check_output(test_script)
3752
        self.assertEqual(rc, "0")
3753
        if not TEST_WITH_ROCM:
3754
            # Check that `cuInit` was not called during the import
3755
            # By using ctypes and calling cuDeviceCountGet() and expect CUDA_ERROR_NOT_INITIALIZED == 3
3756
            # See https://github.com/pytorch/pytorch/issues/116276 for more details
3757
            libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll"
3758
            cuda_driver_api_call = (
3759
                f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))"
3760
            )
3761
            rc = check_output(
3762
                f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})"
3763
            )
3764
            self.assertEqual(rc, "3")
3765

3766
    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
3767
    def test_hip_device_count(self):
3768
        """Validate device_count works with both CUDA/HIP visible devices"""
3769
        test_script = """\
3770
import torch
3771
import os
3772
print(f"{torch.cuda.device_count()}")
3773
"""
3774
        custom_envs = [
3775
            {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
3776
            {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
3777
            {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
3778
        ]
3779

3780
        for env_config in custom_envs:
3781
            env = os.environ.copy()
3782
            for key, value in env_config.items():
3783
                if value is None:
3784
                    env.pop(key, None)
3785
                else:
3786
                    env[key] = value
3787
            r = (
3788
                subprocess.check_output([sys.executable, "-c", test_script], env=env)
3789
                .decode("ascii")
3790
                .strip()
3791
            )
3792
            self.assertEqual("1", r)
3793

3794
    @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices")
3795
    def test_device_count_not_cached_pre_init(self):
3796
        visible_devices = (
3797
            "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES"
3798
        )
3799
        test_script = f"""\
3800
import torch
3801
import os
3802
r1 = torch.cuda.device_count()
3803
os.environ['{visible_devices}'] = '0'
3804
r2 = torch.cuda.device_count()
3805
torch.empty(10, device='cuda')
3806
print(f"{{r1}}, {{r2}}")
3807
"""
3808

3809
        r = (
3810
            subprocess.check_output([sys.executable, "-c", test_script])
3811
            .decode("ascii")
3812
            .strip()
3813
        )
3814

3815
        x = torch.cuda.device_count()
3816
        self.assertEqual(f"{x}, 1", r)
3817

3818
    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
3819
    def test_gds_fails_in_ci(self):
3820
        if IS_WINDOWS or TEST_WITH_ROCM:
3821
            error_msg = "is not supported on this platform"
3822
        else:
3823
            error_msg = "cuFileHandleRegister failed"
3824
        with TemporaryFileName() as f:
3825
            with self.assertRaisesRegex(RuntimeError, error_msg):
3826
                file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
3827

3828

3829
@torch.testing._internal.common_utils.markDynamoStrictTest
3830
class TestCudaMallocAsync(TestCase):
3831
    @unittest.skipIf(
3832
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3833
    )
3834
    def test_memory_snapshot(self):
3835
        try:
3836
            torch.cuda.memory.empty_cache()
3837
            torch.cuda.memory._record_memory_history("state", stacks="python")
3838
            # make x the second block in a segment
3839
            torch.rand(2 * 311, 411, device="cuda")
3840
            unused = torch.rand(310, 410, device="cuda")
3841
            x = torch.rand(311, 411, device="cuda")
3842

3843
            # create a bunch of tensors that all will tile into the
3844
            # same segment to  exercise the history merging code
3845
            # 512B is the minimum block size,
3846
            # so we allocate all the tensors to this size to make sure
3847
            # they tile evenly
3848
            tensors = [torch.rand(128, device="cuda") for _ in range(1000)]
3849
            while tensors:
3850
                del tensors[randint(0, len(tensors) - 1)]
3851

3852
            # exercise the history trimming code
3853
            torch.rand(128 * 5, device="cuda")
3854

3855
            ss = torch.cuda.memory._snapshot()
3856
            found_it = False
3857
            for seg in ss["segments"]:
3858
                self.assertTrue("frames" in seg)
3859
                for b in seg["blocks"]:
3860
                    if b["requested_size"] == 311 * 411 * 4:
3861
                        self.assertTrue("test_cuda" in b["frames"][0]["filename"])
3862
                        found_it = True
3863
                        self.assertEqual(x.untyped_storage().data_ptr(), b["address"])
3864
            self.assertTrue(found_it)
3865

3866
            if not IS_WINDOWS:
3867
                with tempfile.NamedTemporaryFile() as f:
3868
                    torch.cuda.memory._save_segment_usage(f.name)
3869
                    with open(f.name) as f2:
3870
                        self.assertTrue("test_cuda.py" in f2.read())
3871
            del unused
3872
            del x
3873
            torch.cuda.empty_cache()
3874
            ss = torch.cuda.memory._snapshot()
3875
            self.assertTrue(
3876
                ss["device_traces"][0][-1]["action"]
3877
                in ("segment_free", "segment_unmap")
3878
            )
3879

3880
        finally:
3881
            torch.cuda.memory._record_memory_history(None)
3882

3883
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
3884
    def test_direct_traceback(self):
3885
        from torch._C._profiler import gather_traceback, symbolize_tracebacks
3886

3887
        c = gather_traceback(True, True, True)
3888
        (r,) = symbolize_tracebacks([c])
3889
        r = str(r)
3890
        self.assertTrue("test_cuda.py" in r)
3891
        self.assertTrue("unwind" in r)
3892

3893
    @unittest.skipIf(
3894
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3895
    )
3896
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3897
    def test_memory_snapshot_with_cpp(self):
3898
        try:
3899
            torch.cuda.memory.empty_cache()
3900
            torch.cuda.memory._record_memory_history("state", stacks="all")
3901
            x = torch.rand(311, 411, device="cuda")
3902

3903
            ss = torch.cuda.memory._snapshot()["segments"]
3904
            found_it = False
3905
            for seg in ss:
3906
                for b in seg["blocks"]:
3907
                    if b["requested_size"] == 311 * 411 * 4:
3908
                        self.assertTrue("::rand" in str(b["frames"]))
3909
                        found_it = True
3910
            self.assertTrue(found_it)
3911

3912
        finally:
3913
            torch.cuda.memory._record_memory_history(None)
3914

3915
    @skipIfRocm
3916
    def test_memory_profiler_viz(self):
3917
        with torch.profiler.profile(
3918
            with_stack=True, profile_memory=True, record_shapes=True
3919
        ) as prof:
3920
            x = torch.rand(128, 128, device="cuda")
3921
            x * x + x * x
3922
        plot = profile_plot(prof)
3923
        plot = json.dumps(_profile_to_snapshot(prof))
3924
        self.assertTrue("test_cuda.py" in plot)
3925
        self.assertTrue("test_memory_profiler_viz" in plot)
3926
        self.assertTrue("category" in plot)
3927

3928
    @unittest.skipIf(
3929
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3930
    )
3931
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3932
    def test_cycles(self):
3933
        fired = False
3934

3935
        def observer(html):
3936
            nonlocal fired
3937
            fired = True
3938
            self.assertTrue("torch.Tensor" in html)
3939
            self.assertTrue("test_cuda" in html)
3940
            self.assertTrue("cell_contents" in html)
3941

3942
        disarm = observe_tensor_cycles(observer)
3943

3944
        def noop():
3945
            pass
3946

3947
        try:
3948

3949
            def create():
3950
                x = torch.empty(3, 4, device="cuda")
3951

3952
                def foo(p):
3953
                    if p:
3954
                        return foo(not p)
3955
                    else:
3956
                        return x
3957

3958
                return foo
3959

3960
            create()
3961
            gc.collect()
3962
            # the callback has to run outside of the collect
3963
            # call so it doesn't actual fire until the next
3964
            # method call after a gc.collect
3965
            noop()
3966
            self.assertTrue(fired)
3967
        finally:
3968
            disarm()
3969

3970
    @unittest.skipIf(
3971
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
3972
    )
3973
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
3974
    def test_memory_plots(self):
3975
        for context, stacks in (
3976
            ("all", "all" if IS_LINUX else "python"),
3977
            ("all", "python"),
3978
            (None, "python"),
3979
        ):
3980
            try:
3981
                torch.cuda.memory.empty_cache()
3982
                torch.cuda.memory._record_memory_history(
3983
                    "all", context=context, stacks=stacks
3984
                )
3985

3986
                def run():
3987
                    x = torch.rand(128, 128, device="cuda")
3988
                    x * x + x * x
3989

3990
                run()
3991
                cpp = stacks == "all"
3992
                record_context = context is not None
3993
                ss = torch.cuda.memory._snapshot()
3994

3995
                tplot = trace_plot(ss)
3996
                splot = segment_plot(ss)
3997
                text = json.dumps(ss)
3998

3999
                self.assertTrue(record_context == ("test_memory_plots" in text))
4000
                self.assertTrue(cpp == ("::rand" in text))
4001
                self.assertTrue(str(128 * 128 * 4) in text)
4002

4003
            finally:
4004
                torch.cuda.memory._record_memory_history(None)
4005

4006
    @unittest.skipIf(
4007
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
4008
    )
4009
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
4010
    def test_memory_plots_free_stack(self):
4011
        for context in ["alloc", "all", "state"]:
4012
            try:
4013
                torch.cuda.memory.empty_cache()
4014
                torch.cuda.memory._record_memory_history(context=context)
4015
                x = None
4016

4017
                def thealloc():
4018
                    nonlocal x
4019
                    x = torch.rand(3, 4, device="cuda")
4020

4021
                def thefree():
4022
                    nonlocal x
4023
                    del x
4024

4025
                thealloc()
4026
                thefree()
4027
                ss = json.dumps(torch.cuda.memory._snapshot())
4028
                self.assertTrue(("thefree" in ss) == (context == "all"))
4029
                self.assertTrue(("thealloc" in ss) == (context != "state"))
4030
            finally:
4031
                torch.cuda.memory._record_memory_history(None)
4032

4033
    @unittest.skipIf(
4034
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
4035
    )
4036
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
4037
    def test_memory_plots_history_context(self):
4038
        try:
4039
            torch.cuda.memory.empty_cache()
4040
            x = None
4041

4042
            def should_capture1():
4043
                nonlocal x
4044
                x = torch.rand(4, 4, device="cuda")
4045

4046
            def should_not_capture():
4047
                nonlocal x
4048
                x = torch.rand(3, 4, device="cuda")
4049

4050
            def should_capture2():
4051
                nonlocal x
4052
                x = torch.rand(4, 4, device="cuda")
4053

4054
            # Recording with context and python call stacks should capture the call stack.
4055
            torch.cuda.memory._record_memory_history(context="all", stacks="python")
4056
            should_capture1()
4057
            # Recording with context=None should not capture the call stack.
4058
            torch.cuda.memory._record_memory_history(context=None)
4059
            should_not_capture()
4060
            # Recording with context and python call stacks should capture the call stack.
4061
            torch.cuda.memory._record_memory_history(context="all", stacks="python")
4062
            should_capture2()
4063

4064
            ss = json.dumps(torch.cuda.memory._snapshot())
4065
            self.assertTrue("should_capture1" in ss)
4066
            self.assertTrue("should_not_capture" not in ss)
4067
            self.assertTrue("should_capture2" in ss)
4068
        finally:
4069
            torch.cuda.memory._record_memory_history(None)
4070

4071
    @unittest.skipIf(
4072
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
4073
    )
4074
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
4075
    def test_memory_plots_free_segment_stack(self):
4076
        for context in ["alloc", "all", "state"]:
4077
            try:
4078
                torch.cuda.memory.empty_cache()
4079
                torch.cuda.memory._record_memory_history(context=context)
4080
                x = torch.rand(3, 4, device="cuda")
4081
                del x
4082
                torch.cuda.memory.empty_cache()
4083

4084
                ss = json.dumps(torch.cuda.memory._snapshot())
4085
                self.assertTrue(("empty_cache" in ss) == (context == "all"))
4086
            finally:
4087
                torch.cuda.memory._record_memory_history(None)
4088

4089
    @unittest.skipIf(
4090
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
4091
    )
4092
    def test_memory_snapshot_script(self):
4093
        try:
4094
            torch.cuda.memory.empty_cache()
4095
            torch.cuda.memory._record_memory_history("state", stacks="python")
4096

4097
            @torch.jit.script
4098
            def foo():
4099
                return torch.rand(311, 411, device="cuda")
4100

4101
            x = foo()
4102

4103
            ss = torch.cuda.memory._snapshot()["segments"]
4104
            found_it = False
4105
            for seg in ss:
4106
                for b in seg["blocks"]:
4107
                    if b["requested_size"] == 311 * 411 * 4:
4108
                        self.assertTrue(b["frames"][0]["name"] == "foo")
4109
                        found_it = True
4110
            self.assertTrue(found_it)
4111

4112
        finally:
4113
            torch.cuda.memory._record_memory_history(None)
4114

4115
    def test_max_split_expandable(self):
4116
        torch.cuda.memory.empty_cache()
4117
        mb = 1024 * 1024
4118
        _, all_memory = torch.cuda.memory.mem_get_info()
4119
        total_allowed = 120 * mb
4120
        fraction_allowed = total_allowed / all_memory
4121
        assert int(fraction_allowed * all_memory) == total_allowed
4122
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
4123

4124
        def alloc(n):
4125
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
4126

4127
        torch.cuda.memory._set_allocator_settings(
4128
            "expandable_segments:False,max_split_size_mb:40"
4129
        )
4130
        a = alloc(40)
4131
        torch.cuda.memory._set_allocator_settings(
4132
            "expandable_segments:True,max_split_size_mb:40"
4133
        )
4134
        b = alloc(40)
4135
        torch.cuda.memory._set_allocator_settings(
4136
            "expandable_segments:False,max_split_size_mb:40"
4137
        )
4138
        c = alloc(40)
4139
        with self.assertRaises(torch.OutOfMemoryError):
4140
            alloc(40)
4141
        del a, b, c
4142
        # force release_cached_blocks to run with some expandable segments in the free list
4143
        alloc(120)
4144

4145
    def test_garbage_collect_expandable(self):
4146
        torch.cuda.memory.empty_cache()
4147
        mb = 1024 * 1024
4148
        _, all_memory = torch.cuda.memory.mem_get_info()
4149
        total_allowed = 120 * mb
4150
        fraction_allowed = total_allowed / all_memory
4151
        assert int(fraction_allowed * all_memory) == total_allowed
4152
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)
4153

4154
        def alloc(n):
4155
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")
4156

4157
        torch.cuda.memory._set_allocator_settings(
4158
            "expandable_segments:False,garbage_collection_threshold:0.5"
4159
        )
4160
        a = alloc(40)
4161
        torch.cuda.memory._set_allocator_settings(
4162
            "expandable_segments:True,garbage_collection_threshold:0.5"
4163
        )
4164
        b = alloc(40)
4165
        del a, b
4166
        # causes GC to run. The expandable segment block will be split
4167
        # so GC would not attempt to free it anyway, but this at least makes sure
4168
        # expandable_segment blocks can be in the free list when this is called.
4169
        alloc(80)
4170

4171
    def test_allocator_settings(self):
4172
        def power2_div(size, div_factor):
4173
            pow2 = 1
4174
            while pow2 < size:
4175
                pow2 = pow2 * 2
4176
            if pow2 == size:
4177
                return pow2
4178
            step = pow2 / 2 / div_factor
4179
            ret = pow2 / 2
4180
            while ret < size:
4181
                ret = ret + step
4182
            return ret
4183

4184
        torch.cuda.memory.empty_cache()
4185
        key_allocated = (
4186
            "active_bytes.all.allocated"
4187
            if not TEST_CUDAMALLOCASYNC
4188
            else "allocated_bytes.all.current"
4189
        )
4190
        key_requested = "requested_bytes.all.allocated"
4191

4192
        nelems = 21 * 1024 * 1024
4193
        nbytes = 4 * nelems  # floats are 4 bytes
4194

4195
        nelems_big = 100 * 1024 * 1024
4196
        nbytes_big = 4 * nelems_big  # floats are 4 bytes
4197

4198
        start_mem = torch.cuda.memory_stats()[key_allocated]
4199
        torch.cuda.memory._set_allocator_settings("")
4200
        x = torch.rand(nelems, device="cuda")
4201

4202
        # test roundup_power2_divisions single value syntax
4203
        reg_mem = torch.cuda.memory_stats()[key_allocated]
4204
        start_requested = torch.cuda.memory_stats()[key_requested]
4205
        torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4")
4206
        y = torch.rand(nelems, device="cuda")
4207

4208
        pow2_div4_mem = torch.cuda.memory_stats()[key_allocated]
4209
        current_requested = torch.cuda.memory_stats()[key_requested]
4210

4211
        self.assertTrue(reg_mem - start_mem == nbytes)
4212
        if not TEST_CUDAMALLOCASYNC:
4213
            # not supported with the cudaMallocAsync backend
4214
            self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
4215
            self.assertTrue(current_requested - start_requested == nbytes)
4216

4217
        torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
4218
        torch.cuda.memory._set_allocator_settings(
4219
            "garbage_collection_threshold:0.5,max_split_size_mb:40"
4220
        )
4221

4222
        # should have reset the power2 divisions now
4223
        torch.cuda.memory.empty_cache()
4224
        start_mem = torch.cuda.memory_stats()[key_allocated]
4225
        z = torch.rand(nelems, device="cuda")
4226
        reg_mem = torch.cuda.memory_stats()[key_allocated]
4227
        self.assertTrue(reg_mem - start_mem == nbytes)
4228

4229
        # roundup_power2_divisions knob array syntax
4230
        torch.cuda.memory.empty_cache()
4231
        torch.cuda.memory._set_allocator_settings(
4232
            "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]"
4233
        )
4234
        start_mem = torch.cuda.memory_stats()[key_allocated]
4235
        w = torch.rand(nelems, device="cuda")
4236

4237
        pow2_div8_mem = torch.cuda.memory_stats()[key_allocated]
4238
        if not TEST_CUDAMALLOCASYNC:
4239
            # not supported with the cudaMallocAsync backend
4240
            self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8))
4241

4242
        torch.cuda.memory.empty_cache()
4243
        start_mem = torch.cuda.memory_stats()[key_allocated]
4244
        v = torch.rand(nelems_big, device="cuda")
4245

4246
        pow2_div2_mem = torch.cuda.memory_stats()[key_allocated]
4247
        if not TEST_CUDAMALLOCASYNC:
4248
            # not supported with the cudaMallocAsync backend
4249
            self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2))
4250

4251
        torch.cuda.memory.empty_cache()
4252
        torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True")
4253
        start_mem = torch.cuda.memory_stats()[key_allocated]
4254
        w = torch.rand(nelems, device="cuda")
4255
        reg_mem = torch.cuda.memory_stats()[key_allocated]
4256
        self.assertTrue(reg_mem - start_mem == nbytes)
4257

4258
        with self.assertRaises(RuntimeError):
4259
            torch.cuda.memory._set_allocator_settings("foo:1,bar:2")
4260

4261
        with self.assertRaises(RuntimeError):
4262
            torch.cuda.memory._set_allocator_settings(
4263
                "garbage_collection_threshold:1.2"
4264
            )
4265

4266
        with self.assertRaises(RuntimeError):
4267
            torch.cuda.memory._set_allocator_settings("max_split_size_mb:2")
4268

4269
        with self.assertRaises(RuntimeError):
4270
            torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none")
4271

4272
        with self.assertRaises(RuntimeError):
4273
            torch.cuda.memory._set_allocator_settings(
4274
                "pinned_use_cuda_host_register:none"
4275
            )
4276

4277
        with self.assertRaises(RuntimeError):
4278
            torch.cuda.memory._set_allocator_settings(
4279
                "pinned_num_register_threads:none"
4280
            )
4281

4282
        with self.assertRaises(RuntimeError):
4283
            torch.cuda.memory._set_allocator_settings(
4284
                "pinned_num_register_threads:1024"
4285
            )
4286

4287
    @parametrize("max_split_size_mb_setting", [False, True])
    def test_raises_oom(self, max_split_size_mb_setting):
        """An impossibly large allocation raises ``torch.cuda.OutOfMemoryError``.

        When ``max_split_size_mb_setting`` is true, ``max_split_size_mb`` is
        configured first: the CUDA caching allocator returns early while
        searching available blocks if that knob is unset, so setting it
        exercises more of the search code.
        """
        # 1024**4 elements of the default dtype -- far beyond any device's memory.
        impossible_numel = 1024 * 1024 * 1024 * 1024
        if max_split_size_mb_setting:
            # NOTE(review): this allocator setting is never restored here, so it
            # stays in effect for later tests in the same process.
            torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024")
            torch.cuda.memory.empty_cache()
        with self.assertRaises(torch.cuda.OutOfMemoryError):
            torch.empty(impossible_numel, device="cuda")
4297

4298
    @unittest.skipIf(
        not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux"
    )
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_cpp_memory_snapshot_pickle(self):
        """Check the pickled C++ memory snapshot with and without stack context.

        A small inline extension exposes ``torch::cuda::_memory_snapshot_pickled``
        and ``torch::cuda::_record_memory_history``. For each ``ctx`` setting,
        history recording is enabled, a tensor of known size is allocated from
        a TorchScript function, and the snapshot must contain:

        * an ``active_allocated`` block of exactly that requested size whose
          frames (only checked when ``ctx`` is on) mention C++, script and
          Python code, and
        * a last device-trace entry that is that allocation.

        Recording is always switched off again in ``finally``.
        """
        from torch.utils.cpp_extension import load_inline

        source = """
        #include <torch/csrc/cuda/memory_snapshot.h>
        py::object do_snapshot() {
            std::string data = torch::cuda::_memory_snapshot_pickled();
            return py::bytes(data);
        }
        void record(bool e, bool ctx) {
            torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx);
        }
        """
        m = load_inline(
            name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"]
        )
        # torch.rand(311, 411) with the default 4-byte float dtype.
        expected_bytes = 311 * 411 * 4
        for ctx in (False, True):
            try:
                m.record(True, ctx)

                @torch.jit.script
                def the_script_fn():
                    return torch.rand(311, 411, device="cuda")

                def run():
                    # Keep the tensor referenced while the snapshot is taken so
                    # its block is still reported as "active_allocated".
                    t = the_script_fn()
                    return pickle.loads(m.do_snapshot())

                mem = run()
                found = False
                for s in mem["segments"]:
                    for b in s["blocks"]:
                        # "requested_size" is only read for active blocks,
                        # matching the short-circuit order of a nested check.
                        if (
                            b["state"] != "active_allocated"
                            or b["requested_size"] != expected_bytes
                        ):
                            continue
                        if ctx:
                            frame_text = str(b["frames"])
                            # C++ frame
                            self.assertIn("::rand", frame_text)
                            # script frame
                            self.assertIn("the_script_fn", frame_text)
                            # python frame
                            self.assertIn("case.py", frame_text)
                        found = True
                last_action = mem["device_traces"][0][-1]
                self.assertEqual(last_action["action"], "alloc")
                self.assertEqual(last_action["size"], expected_bytes)
                self.assertTrue(found)
            finally:
                m.record(False, False)
4353

4354
    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    def test_notifies_oom(self):
        """A registered out-of-memory observer fires when an allocation fails."""
        oom_seen = False

        def record_oom(device, alloc, device_alloc, device_free):
            # Only the fact that the callback ran matters; its arguments are ignored.
            nonlocal oom_seen
            oom_seen = True

        # NOTE(review): no call to detach this observer is visible here, so it
        # presumably stays registered after the test.
        torch._C._cuda_attach_out_of_memory_observer(record_oom)
        impossible_numel = 1024 * 1024 * 1024 * 1024
        with self.assertRaises(torch.cuda.OutOfMemoryError):
            torch.empty(impossible_numel, device="cuda")
        self.assertTrue(oom_seen)
4366

4367
    def test_allocator_fuzz(self):
        """Randomly interleave allocations, frees and cache flushes.

        Each allocation is an int32 CUDA tensor filled with a unique tag. When a
        tensor is freed it must still hold its tag, so memory handed out twice
        or overwritten is detected. Whenever the live element count reaches
        1 GiB / (4 * 10) elements, tensors are freed until it drops below that.
        """
        # A private generator seeded like the global one yields the identical
        # action sequence without mutating (and having to restore) the global
        # ``random`` state.
        rng = random.Random(123)
        num_steps = 10000
        max_live_elems = 1024 * 1024 * 1024 / (4 * 10)

        live = []  # (tag, tensor) pairs that are still allocated
        live_elems = 0
        next_tag = 0

        def alloc():
            nonlocal live_elems, next_tag
            numel = rng.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4)
            tensor = torch.full((numel,), next_tag, dtype=torch.int32, device="cuda")
            live.append((next_tag, tensor))
            next_tag += 1
            live_elems += numel

        def free():
            nonlocal live_elems
            tag, tensor = live.pop(rng.randrange(0, len(live)))
            # A unittest assertion, unlike a bare ``assert``, survives ``python -O``.
            self.assertTrue(torch.all(tensor == tag))
            live_elems -= tensor.numel()

        actions = [alloc, free, torch.cuda.memory.empty_cache]
        for _ in range(num_steps):
            while live_elems >= max_live_elems:
                free()
            # ``free`` gets zero weight while nothing is allocated.
            (action,) = rng.choices(actions, weights=[1, 1 if live else 0, 0.1])
            action()
4399

4400
    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_nvml_get_handler(self):
        """The device-management library handle can be obtained.

        ROCm builds (``torch.version.hip`` set) query the AMD SMI handler,
        all other builds the pynvml handler.
        """
        # NOTE(review): the skip message implies TEST_PYNVML is truthy when
        # pynvml is *missing*; confirm against its definition.
        if torch.version.hip:
            handler = torch.cuda._get_amdsmi_handler()
        else:
            handler = torch.cuda._get_pynvml_handler()
        self.assertIsNotNone(handler)
4406

4407
    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_temperature(self):
        """The reported GPU temperature lies within 0..150 inclusive."""
        reading = torch.cuda.temperature()
        self.assertGreaterEqual(reading, 0)
        self.assertLessEqual(reading, 150)
4410

4411
    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_power_draw(self):
        """The reported power draw is non-negative."""
        self.assertGreaterEqual(torch.cuda.power_draw(), 0)
4414

4415
    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_clock_speed(self):
        """The reported clock rate is non-negative."""
        self.assertGreaterEqual(torch.cuda.clock_rate(), 0)
4418

4419

4420
# Allocation sizes used by the block-state tests; they are passed to
# int8_cuda (a uint8 tensor), so each is also a byte count.
# NOTE(review): presumably chosen to mirror the caching allocator's
# minimum-block / small-size / small-buffer / large-buffer constants -- confirm.
MIN_BLOCK_SIZE = 512  # 512 B
SMALL_SIZE = 1 << 20  # 1 MiB == 1048576
SMALL_BUFFER = 2 << 20  # 2 MiB == 2097152
LARGE_BUFFER = 20 << 20  # 20 MiB == 20971520
4424

4425

4426
def get_cudagraph_segments(pool_id):
    """Return the allocator snapshot segments whose pool id equals ``pool_id``."""
    return list(
        filter(
            lambda seg: seg["segment_pool_id"] == pool_id,
            torch.cuda.memory_snapshot(),
        )
    )
4429

4430

4431
def get_all_cudagraph_segments():
    """Return every allocator snapshot segment whose pool id is not ``(0, 0)``.

    NOTE(review): ``(0, 0)`` is presumably the default (non-graph) pool, so the
    result is the set of segments owned by CUDA-graph / private pools.
    """
    excluded_pool = (0, 0)
    snapshot = torch.cuda.memory_snapshot()
    return [seg for seg in snapshot if seg["segment_pool_id"] != excluded_pool]
4434

4435

4436
def cudagraphify(fn, inputs, pool=None):
    """Warm ``fn(*inputs)`` up on a side stream, then capture it in a CUDA graph.

    Args:
        fn: callable to capture.
        inputs: positional arguments passed to ``fn`` for warm-up and capture.
        pool: optional memory pool handle forwarded to ``torch.cuda.graph``.

    Returns:
        ``(graph, static_outputs)`` where ``static_outputs`` is whatever ``fn``
        returned during capture.

    Raises:
        unittest.SkipTest: when ``TEST_CUDA_GRAPH`` is false.
    """
    if not TEST_CUDA_GRAPH:
        raise unittest.SkipTest("cuda graph test is skipped")

    # One eager run on a side stream that is fenced against the current stream
    # before and after, with full device syncs around it.
    torch.cuda.synchronize()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        fn(*inputs)
    side_stream.synchronize()
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # Capture on the same side stream, optionally into a shared pool.
    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph, stream=side_stream, pool=pool):
        captured_outputs = fn(*inputs)

    return cuda_graph, captured_outputs
4454

4455

4456
def int8_cuda(size):
    """Allocate a 1-D CUDA tensor of ``size`` ones.

    Despite the name the dtype is ``torch.uint8``, so ``size`` is also the
    number of bytes requested.
    """
    return torch.ones((size,), dtype=torch.uint8, device="cuda")
4458

4459

4460
def live_blocks(pool_id):
    """Count the blocks in pool ``pool_id`` whose state is ``"active_allocated"``.

    Takes a single allocator snapshot via ``get_cudagraph_segments``.
    """
    return sum(
        block["state"] == "active_allocated"
        for segment in get_cudagraph_segments(pool_id)
        for block in segment["blocks"]
    )
4467

4468

4469
def tensor_metadata(x):
    """Record what is needed to rebuild a view of ``x`` over its storage.

    Returns a dict with the storage's byte size (``nbytes``) and base pointer
    (``data_ptr``) plus the tensor's ``size``, ``stride``, ``dtype``,
    ``device`` and ``storage_offset``. Only plain values are stored, so the
    dict does not hold a reference to ``x`` or its storage.
    """
    storage = x.untyped_storage()
    return dict(
        nbytes=storage.nbytes(),
        data_ptr=storage.data_ptr(),
        size=x.shape,
        stride=x.stride(),
        dtype=x.dtype,
        device=x.device,
        storage_offset=x.storage_offset(),
    )
4479

4480

4481
def reconstruct_from_tensor_metadata(metadata):
4482
    s = torch._C._construct_storage_from_data_pointer(
4483
        metadata["data_ptr"], metadata["device"], metadata["nbytes"]
4484
    )
4485
    t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"])
4486
    t.set_(
4487
        source=s,
4488
        storage_offset=metadata["storage_offset"],
4489
        size=metadata["size"],
4490
        stride=metadata["stride"],
4491
    )
4492
    return t
4493

4494

4495
@unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestBlockStateAbsorption(TestCase):
    """Checkpoint and restore the caching-allocator state of CUDA graph pools.

    Tests capture a function into a CUDA graph, read the pool state with
    ``torch._C._cuda_getCheckpointState``, re-apply it with
    ``torch._C._cuda_setCheckpointPoolState``, and compare allocator snapshots
    or live-block counts before and after.
    """

    @property
    def expandable_segments(self):
        """Whether the allocator runs with expandable segments (``EXPANDABLE_SEGMENTS``)."""
        return EXPANDABLE_SEGMENTS

    def checkCheckpointedBlock(self, before_block, after_block):
        """Assert two snapshot blocks agree on ``size`` and ``state``."""
        for field in ("size", "state"):
            self.assertEqual(before_block[field], after_block[field])

    def checkCheckpointedState(self, before_segments, after_segments):
        """Assert every segment of ``before_segments`` reappears unchanged.

        ``after_segments`` may contain additional segments, but each segment in
        ``before_segments`` must exist at the same address in it with the same
        segment-level fields and the same per-block size/state.
        """
        after_ptr_to_segment = {
            segment["address"]: segment for segment in after_segments
        }

        for before_segment in before_segments:
            self.assertIn(before_segment["address"], after_ptr_to_segment)
            after_segment = after_ptr_to_segment[before_segment["address"]]

            for field in (
                "device",
                "total_size",
                "allocated_size",
                "active_size",
                "segment_type",
                "segment_pool_id",
            ):
                self.assertEqual(before_segment[field], after_segment[field])

            self.assertEqual(
                len(before_segment["blocks"]), len(after_segment["blocks"])
            )
            for before_block, after_block in zip(
                before_segment["blocks"], after_segment["blocks"]
            ):
                self.checkCheckpointedBlock(before_block, after_block)

    @staticmethod
    def setCheckpointPoolState(
        device, state, stale_storages_ptr, storages_deleters=None
    ):
        """Apply a checkpointed pool ``state`` on ``device``.

        ``stale_storages_ptr`` and ``storages_deleters`` are sequences of
        tensors. They are converted to their storages' ``_cdata`` handles
        before being forwarded to ``torch._C._cuda_setCheckpointPoolState``.
        ``storages_deleters`` may be ``None`` or empty.
        """
        stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr]
        storages_deleters = (
            []
            if not storages_deleters
            else [t.untyped_storage()._cdata for t in storages_deleters]
        )
        torch._C._cuda_setCheckpointPoolState(
            device, state, stale_storages_ptr, storages_deleters
        )

    def checkFunction(self, fn, inputs, pool=None):
        """Graph-capture ``fn(*inputs)`` and check a checkpoint round trip.

        The pool's segments are snapshotted, its checkpoint state is read and
        immediately re-applied, and the pool's segments must be unchanged.
        """
        graph, outputs = cudagraphify(fn, inputs, pool=pool)

        pool_id = graph.pool()
        device = outputs[0].device.index

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(device, pool_id)
        self.setCheckpointPoolState(device, state, [], [])

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

    def setUp(self):
        super().setUp()
        # Baseline count of non-default-pool segments; tearDown checks that the
        # test did not leave additional ones behind.
        self.segment_length = len(get_all_cudagraph_segments())

    def tearDown(self):
        # The base-class teardown must run even when the leak check fails.
        try:
            torch.cuda.synchronize()
            gc.collect()
            torch.cuda.empty_cache()

            self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length)
        finally:
            super().tearDown()

    def test_simple(self):
        def foo():
            x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8)
            x = x + x
            x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE)
            y = int8_cuda(SMALL_SIZE) + x1
            z = int8_cuda(SMALL_SIZE)
            return x, y, z

        self.checkFunction(foo, [])

    def test_allocated_in_middle_of_segment(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5].add_(2)

        self.checkFunction(foo, [])

    def test_multiple_middle_allocations(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[8]

        self.checkFunction(foo, [])

    def test_middle_allocations_contiguous(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[6]

        self.checkFunction(foo, [])

    def test_additional_free_following_checkpoint(self):
        def foo():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        def foo2():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        graph, outputs = cudagraphify(foo, [])
        pool_id = graph.pool()

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        # A second capture into the same pool allocates after the checkpoint;
        # its outputs are passed as stale storages when restoring.
        graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())

        self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])

        del outputs2

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

    # TODO: re-enable
    # def test_additional_free_error(self):
    #     def foo():
    #         return int8_cuda(MIN_BLOCK_SIZE),

    #     def foo2():
    #         return int8_cuda(MIN_BLOCK_SIZE),

    #     graph, outputs = cudagraphify(foo, [])
    #     pool_id = graph.pool()

    #     segments_before_checkpoint = get_cudagraph_segments(pool_id)

    #     state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

    #     graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
    #     with self.assertRaisesRegex(Exception, "being manually freed must be passed"):
    #         self.setCheckpointPoolState(outputs[0].device.index, state, [], [])

    def test_tensor_dies_after_checkpoint(self):
        def foo():
            return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE)

        graph, outputs = cudagraphify(foo, [])
        pool_id = graph.pool()
        device = outputs[0].device.index

        segments_before_checkpoint = get_cudagraph_segments(pool_id)
        state = torch._C._cuda_getCheckpointState(device, pool_id)

        output_data_ptrs = [output.data_ptr() for output in outputs]

        del outputs

        self.setCheckpointPoolState(device, state, [], [])

        # After restoring, the blocks are live again without any tensor owning
        # them and must be released through the raw allocator API.
        self.assertEqual(live_blocks(pool_id), 2)
        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0])
        self.assertEqual(live_blocks(pool_id), 1)
        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1])
        self.assertEqual(live_blocks(pool_id), 0)

    def test_assigning_back_deleter_fns_to_tensor(self):
        def foo(x):
            return (
                int8_cuda(SMALL_BUFFER) + x,
                int8_cuda(SMALL_BUFFER) + x,
                int8_cuda(LARGE_BUFFER) + x,
            )

        inp = torch.tensor([1], device="cuda")
        graph, outputs = cudagraphify(foo, [inp])
        pool_id = graph.pool()
        graph.replay()

        device = outputs[0].device.index

        # Index rather than iterate: a loop variable would keep the last output
        # alive past ``del outputs`` and break the live-block counts below.
        for i in range(len(outputs)):
            self.assertTrue(outputs[i].mean(dtype=torch.float) == 2)

        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        output_ptrs = [output.untyped_storage().data_ptr() for output in outputs]
        ten_metadata = [tensor_metadata(t) for t in outputs]

        self.assertEqual(live_blocks(pool_id), 3)

        del outputs

        self.assertEqual(live_blocks(pool_id), 0)

        reconstructed_tensors = [
            reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata
        ]

        for i in range(len(reconstructed_tensors)):
            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2)

        inp.add_(1)
        graph.replay()

        for i in range(len(reconstructed_tensors)):
            self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3)

        self.setCheckpointPoolState(
            device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]]
        )

        self.assertEqual(live_blocks(pool_id), 3)

        reconstructed_tensors[0] = None
        self.assertEqual(live_blocks(pool_id), 2)

        reconstructed_tensors[1] = None
        self.assertEqual(live_blocks(pool_id), 1)

        # should not change, we did not pass it in to swap data ptrs
        reconstructed_tensors[2] = None
        self.assertEqual(live_blocks(pool_id), 1)

        torch._C._cuda_cudaCachingAllocator_raw_delete(output_ptrs[2])

        self.assertEqual(live_blocks(pool_id), 0)

    @skipIfNoTorchVision
    def test_resnet(self):
        import torchvision

        m = torchvision.models.resnet50()
        m.eval()
        m = m.cuda()

        inp = torch.rand([1, 3, 255, 255], device="cuda")
        self.checkFunction(m, [inp])

    def test_check_pool_live_allocations(self):
        def foo():
            return torch.ones([4], device="cuda")

        pool = torch.cuda.graph_pool_handle()
        graph, outputs = cudagraphify(foo, [], pool=pool)

        index = outputs[0].device.index

        def check(live_dps):
            return torch._C._cuda_checkPoolLiveAllocations(index, pool, live_dps)

        self.assertTrue(check({outputs[0].data_ptr()}))

        self.assertFalse(check({outputs[0].data_ptr(), 0}))
        self.assertFalse(check(set()))

        del outputs
        self.assertTrue(check(set()))

    def test_allocate_in_thread_to_pool(self):
        def foo():
            return torch.rand([4], device="cuda")

        pool = torch.cuda.graph_pool_handle()
        graph, outputs = cudagraphify(foo, [], pool=pool)
        device = outputs[0].device.index
        del outputs

        @contextlib.contextmanager
        def _use_cuda_memory_pool_manager(device, mem_pool):
            """
            Context manager that routes new allocations on a fresh side stream
            into the CUDA graph pool ``mem_pool``. If you use this manager, all
            cudagraph tensors in use should be reflected in the allocator or
            they will be overwritten. ``mem_pool`` must already exist (here it
            was created by an earlier graph capture).
            """
            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            # ``with`` guarantees the stream context is exited even if
            # beginning pool allocation raises.
            with torch.cuda.stream(stream):
                torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
                try:
                    yield
                finally:
                    torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
                    torch._C._cuda_releasePool(device, mem_pool)

        self.assertEqual(len(get_cudagraph_segments(pool)), 1)

        def use_pool():
            def alloc_three():
                a = int8_cuda(LARGE_BUFFER)
                b = int8_cuda(LARGE_BUFFER)
                # ``a + b`` is the third allocation.
                c = a + b

            with _use_cuda_memory_pool_manager(device, pool):
                # three allocations per alloc_three() call
                for _ in range(10):
                    alloc_three()

            # three more allocations not in pool
            alloc_three()

        def no_pool():
            # two allocations per iteration
            for _ in range(10):
                a = int8_cuda(LARGE_BUFFER)
                b = int8_cuda(LARGE_BUFFER)
                del a, b

        graph_thread = threading.Thread(target=use_pool)
        no_graph_thread = threading.Thread(target=no_pool)
        graph_thread.start()
        no_graph_thread.start()

        graph_thread.join()
        no_graph_thread.join()

        self.assertEqual(
            len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4
        )

        del graph

        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        self.assertEqual(len(get_cudagraph_segments(pool)), 0)

    def test_no_triton_on_import(self):
        """Test that Triton is not imported on first GPU use"""
        script = "import sys; import torch; torch.rand(2, device='cuda'); print('triton' in sys.modules)"

        rc = (
            subprocess.check_output(
                [sys.executable, "-c", script],
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
            .strip()
            .decode("ascii")
        )
        self.assertEqual(rc, "False", "Triton was imported when importing torch!")
4856

4857

4858
class TestMemPool(TestCase):
    """Tests for ``torch.cuda.MemPool`` ids, pluggable allocators and ``MemPoolContext``."""

    def test_mempool_id(self):
        pool1 = torch.cuda.graph_pool_handle()
        pool2 = torch.cuda.MemPool().id

        # first value of id in a user created pool is always zero; check each
        # pool on its own (comparing the two ``== 0`` booleans would also pass
        # if both were non-zero)
        self.assertEqual(pool1[0], 0)
        self.assertEqual(pool2[0], 0)

        # each call to torch.cuda.graph_pool_handle() or torch.cuda.MemPool()
        # increments the id
        self.assertNotEqual(pool1[1], pool2[1])

    def test_mempool_with_allocator(self):
        pool = torch.cuda.MemPool()

        # MemPool doesn't have an allocator by default
        self.assertIsNone(pool.allocator)

        from torch.utils.cpp_extension import load_inline

        dummy_allocator_source = """
        #include <torch/extension.h>
        #include <ATen/cuda/Exceptions.h>
        #include <cuda_runtime_api.h>

        extern "C" {
          C10_EXPORT int called_dummy_alloc = 0;
          C10_EXPORT int called_dummy_free = 0;

          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
            called_dummy_alloc = 123;
            void* ptr;
            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
            return ptr;
          }

          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
            called_dummy_free = 321;
            C10_CUDA_CHECK(cudaFree(ptr));
          }
        }
        """
        dummy_allocator_libname = "dummy_allocator"
        dummy_allocator = load_inline(
            name=dummy_allocator_libname,
            cpp_sources=dummy_allocator_source,
            is_python_module=False,
            keep_intermediates=False,
            verbose=True,
            with_cuda=True,
        )
        allocator = torch.cuda.memory.CUDAPluggableAllocator(
            dummy_allocator,
            "dummy_alloc",
            "dummy_free",
        )
        pool = torch.cuda.MemPool(allocator.allocator())

        # pool should point to the same allocator as the one passed into it
        self.assertEqual(allocator.allocator(), pool.allocator)

        # no allocations happened yet, so called_dummy_alloc should be 0
        alloc_lib = ctypes.CDLL(dummy_allocator)
        called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
        self.assertEqual(called_dummy_alloc.value, 0)

        with torch.cuda.use_mem_pool(pool):
            # Keep the tensor referenced until the end of the test.
            out = torch.randn(1, device="cuda")

        # called_dummy_alloc should be 123 if dummy_alloc was used to allocate
        # out tensor
        self.assertEqual(called_dummy_alloc.value, 123)

    def test_mempool_context(self):
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # there is no active pool if none was made active
        self.assertIsNone(active_pool)

        pool = torch.cuda.MemPool()
        ctx = torch.cuda.MemPoolContext(pool)
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # pool was made active
        self.assertEqual(active_pool, pool)

        del ctx
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # ctx was deleted, so active pool is the previous one
        self.assertIsNone(active_pool)

    def test_mempool_multithread(self):
        pool_ids = []
        active_pool_ids = []

        def create_mempool_and_make_active():
            pool = torch.cuda.MemPool()
            pool_ids.append(pool.id)

            ctx = torch.cuda.MemPoolContext(pool)
            active_pool = torch.cuda.MemPoolContext.active_pool()
            active_pool_ids.append(active_pool.id)
            del ctx

        num_threads = 4
        threads = [
            threading.Thread(target=create_mempool_and_make_active)
            for _ in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # each thread should create a unique mempool, since
        # mempool id creation is atomic
        self.assertEqual(len(set(pool_ids)), num_threads)

        # each thread should have different active mempool, since
        # the pointer to the mempool is thread local
        self.assertEqual(len(set(active_pool_ids)), num_threads)
4981

4982

4983
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
    # These tests will be instantiated with instantiate_device_type_tests
    # to apply the new OptimizerInfo structure.

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs"
    )
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_graph_optims(self, device, dtype, optim_info):
        """A capturable optimizer matches a capturable=False control.

        For each optimizer input, a control optimizer (``capturable=False``)
        runs ``steps_warmup + steps_train`` eager steps. A second optimizer
        built with ``capturable=True`` runs the same warmup eagerly, then the
        remaining steps either by replaying a captured CUDA graph
        (``actually_do_graphs=True``) or eagerly. Final parameters must match.
        """
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        steps_warmup = 3
        steps_train = 2

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            # Force a plain float lr: lr as a Tensor is not supported when
            # capturable=False and foreach=True for torch.optim.adam and
            # torch.optim.adamw, and the control run uses capturable=False.
            kwargs["lr"] = 0.1

            for actually_do_graphs in (True, False):
                params = [
                    torch.randn((i + 5, i + 5), device=device) for i in range(2)
                ] + [torch.randn((), device=device)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]

                # Control (capturable=False)
                kwargs["capturable"] = False

                opt = optim_cls(params_control, **kwargs)
                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads[i][j]
                    opt.step()

                # capturable=True
                kwargs["capturable"] = True
                opt = optim_cls(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i][j]
                    opt.step()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        opt.step()

                for i in range(steps_train):
                    if actually_do_graphs:
                        # The graph was captured against the existing .grad
                        # tensors, so new gradients are copied into them in
                        # place rather than rebinding .grad.
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads[i + steps_warmup][j])
                        g.replay()
                    else:
                        # Passing capturable=True to the constructor and running without graphs should still be
                        # numerically correct, even if it's not ideal for performance.
                        for j, p in enumerate(params_graphed):
                            p.grad = grads[i + steps_warmup][j]
                        opt.step()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @optims(
        [
            optim
            for optim in optim_db
            if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on
        ],
        dtypes=[torch.float32],
    )
    def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
        """Fused optimizers stepped through a GradScaler agree with an eager control.

        The control path steps a fused optimizer through ``GradScaler`` for
        ``steps_warmup + steps_train`` iterations. The graphed path starts from
        the same scaler state, warms up eagerly, then captures
        ``scaler.step``/``scaler.update`` in a CUDA graph and replays it (or,
        for optimizers with a ``capturable`` argument, also runs eagerly with
        ``capturable=True``). Final parameters must match.
        """
        optim_cls = optim_info.optim_cls

        steps_warmup = 3
        steps_train = 2

        optim_inputs = optim_info.optim_inputs_func(device=device)

        for optim_input in optim_inputs:
            kwargs = optim_input.kwargs
            kwargs["fused"] = True

            # Optimizers without a `capturable` argument are only run with graphs.
            for actually_do_graphs in (
                (True, False) if optim_info.has_capturable_arg else (True,)
            ):
                params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients.
                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]
                with torch.no_grad():
                    grads_control = [[g.clone() for g in gs] for gs in grads]
                    grads_graphed = [[g.clone() for g in gs] for gs in grads]

                # Gradient scalers; the graphed one starts from the control's state.
                scaler_for_control = torch.amp.GradScaler("cuda", init_scale=128.0)
                with torch.no_grad():
                    scaler_for_control._lazy_init_scale_growth_tracker(device)

                scaler_for_graphed = torch.amp.GradScaler("cuda")
                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
                with torch.no_grad():
                    scaler_for_graphed._lazy_init_scale_growth_tracker(device)

                # Control (capturable=False)
                if optim_info.has_capturable_arg:
                    kwargs["capturable"] = False
                opt = optim_cls(params_control, **kwargs)

                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads_control[i][j]
                    scaler_for_control.step(opt)
                    scaler_for_control.update()

                # capturable=True
                if optim_info.has_capturable_arg:
                    kwargs["capturable"] = True
                opt = optim_cls(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads_graphed[i][j]
                    scaler_for_graphed.step(opt)
                    scaler_for_graphed.update()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for i in range(steps_train):
                    if actually_do_graphs:
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
                        g.replay()
                    else:
                        # Passing capturable=True to the constructor and running without graphs should still be
                        # numerically correct, even if it's not ideal for performance.
                        for j, p in enumerate(params_graphed):
                            p.grad = grads_graphed[i + steps_warmup][j]
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)

    @onlyNativeDeviceTypes
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=[torch.float32],
    )
    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
        """A fused optimizer with GradScaler + autocast tracks a non-fused control.

        Both models are trained on the same data under ``torch.autocast`` with
        half precision. The scaling side uses the fused optimizer, with and
        without an explicit ``unscale_`` before ``step``. Losses, gradients,
        parameters and optimizer state are compared after every iteration via
        ``TensorTracker``.
        """
        device = device.split(":")[0]
        if device not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        optim_inputs = optim_info.optim_inputs_func(device=device)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            kwargs = optim_input.kwargs
            kwargs["fused"] = True
            for separate_unscale in (True, False):
                torch.manual_seed(20)
                (
                    mod_control,
                    mod_scaling,
                    _,
                    opt_scaling,
                    data,
                    loss_fn,
                    _,
                ) = _create_scaling_case(
                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device
                )
                # Discard the control optimizer returned above and rebuild it
                # with fused=False so it serves as a non-fused reference.
                optimizer_kwargs = deepcopy(kwargs)
                optimizer_kwargs["fused"] = False
                if "lr" not in kwargs:
                    # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr
                    optimizer_kwargs["lr"] = 1.0
                opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs)
                scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
                scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
                tracker = TensorTracker()
                for input, target in data:
                    opt_control.zero_grad()
                    with torch.autocast(device_type=device, dtype=torch.half):
                        output_control = mod_control(input)
                        loss_control = loss_fn(output_control, target)
                    scaler_control.scale(loss_control).backward()
                    scaler_control.step(opt_control)
                    scaler_control.update()

                    opt_scaling.zero_grad()
                    with torch.autocast(device_type=device, dtype=torch.half):
                        output_scaling = mod_scaling(input)
                        loss_scaling = loss_fn(output_scaling, target)
                    scaler_scaling.scale(loss_scaling).backward()
                    if separate_unscale:
                        scaler_scaling.unscale_(opt_scaling)
                    scaler_scaling.step(opt_scaling)
                    scaler_scaling.update()

                    tracker.add(loss_control)
                    tracker.pop_check_set(loss_scaling, self)
                    for param_control, param_scaling in zip(
                        mod_control.parameters(), mod_scaling.parameters()
                    ):
                        tracker.add(param_control.grad)
                        tracker.pop_check_set(param_scaling.grad, self)
                        tracker.add(param_control)
                        tracker.pop_check_set(param_scaling, self)

                        state_control, state_scaling = (
                            opt_control.state[param_control],
                            opt_scaling.state[param_scaling],
                        )

                        for k in state_control:
                            actual = state_scaling[k]
                            if k == "step":
                                actual = actual.squeeze()
                            tracker.add(state_control[k])
                            tracker.pop_check_set(actual, self)

    @onlyCUDA
    @parametrize("in_place_unscale", [False, True])
    @optims(
        [optim for optim in optim_db if "cuda" in optim.supports_fused_on],
        dtypes=[torch.float32],
    )
    def test_grad_scaler_with_preset_grad_scale(
        self, device, dtype, optim_info, in_place_unscale
    ):
        """A user-set ``opt.grad_scale`` is honored in addition to the scaler's scale.

        The gradient starts at 15, the scaler's scale is 5 and the preset
        ``grad_scale`` is 3. With or without a prior in-place ``unscale_``,
        the final gradient must be 15 / (5 * 3) == 1.
        """
        weight = torch.ones((5, 5), device="cuda", requires_grad=True)
        weight.grad = torch.full_like(weight, fill_value=15)
        opt = optim_info.optim_cls([weight], lr=0.1, fused=True)
        scaler = torch.amp.GradScaler(init_scale=5)

        # simulate scaling a loss
        scaler.scale(torch.ones(5))

        if in_place_unscale:
            scaler.unscale_(opt)
            # the gradient should have been divided in-place
            self.assertEqual(weight.grad, torch.full_like(weight, fill_value=3))

        # the user sets a `grad_scale` value which should be fused with the optimizer step
        opt.grad_scale = torch.tensor([3.0], device="cuda")
        scaler.step(opt)

        # check that the user's grad_scale was respected (i.e. the gradient was divided by 5 * 3)
        self.assertEqual(weight.grad, torch.full_like(weight, fill_value=1))

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize("foreach, fused", [(False, False), (True, False), (False, True)])
    @optims(
        [
            optim
            for optim in optim_db
            if "foreach" in optim.supported_impls and "cuda" in optim.supports_fused_on
        ],
        dtypes=[torch.float32],
    )
    def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused):
        """A scaled backward captured in a CUDA graph cooperates with step/update.

        Only the scaled backward is captured; ``scaler.step`` and
        ``scaler.update`` run eagerly after each replay. The inputs are chosen
        so that some replays overflow in half precision, and the scale,
        growth tracker and gradient are checked against expected values.
        """
        torch.cuda.empty_cache()

        scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()

        weight = torch.ones((100,), device="cuda", requires_grad=True)
        opt = optim_info.optim_cls([weight], lr=0.1, foreach=foreach, fused=fused)
        static_input = torch.ones_like(weight)

        # warmup on a side stream before capture
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            loss = (weight.half() * static_input).sum()
            scaler.scale(loss).backward()
        torch.cuda.current_stream().wait_stream(s)

        opt.zero_grad(set_to_none=True)

        # capture
        with torch.cuda.stream(s):
            g.capture_begin()
            loss = (weight.half() * static_input).sum()
            scaler.scale(loss).backward()
            g.capture_end()

        input_vals = [5, 20000, 5, 40000]
        # If the scale gets updated properly, these are the scale, growth tracker,
        # and grad values we expect.
        expected_scales = [4, 2, 2, 1]
        expected_growth_trackers = [1, 0, 1, 0]
        expected_grad_vals = [5 * 4, float("inf"), 5 * 2, float("inf")]

        for data, scale, growth_tracker, grad_val in zip(
            input_vals, expected_scales, expected_growth_trackers, expected_grad_vals
        ):
            static_input.fill_(data)
            g.replay()
            self.assertEqual(weight.grad, torch.full_like(weight.grad, grad_val))
            scaler.step(opt)
            scaler.update()
            self.assertEqual(scaler._scale, scale)
            self.assertEqual(scaler._growth_tracker, growth_tracker)


class TestGDS(TestCase):
    def _get_tmp_dir_fs_type(self):
        """Return the filesystem type backing the resolved ``/tmp`` path.

        Every mounted partition is considered, including memory-backed ones
        such as tmpfs (which ``disk_partitions()`` hides by default). The
        mount whose mountpoint is the longest path-component prefix of
        ``os.path.realpath("/tmp")`` wins. This handles three cases: a
        dedicated ``/tmp`` mount, ``/tmp`` being a symlink into a separately
        mounted directory, and falling back to ``/``. Returns ``""`` if no
        listed mount covers the path.
        """
        my_path = os.path.realpath("/tmp")
        best_mountpoint = ""
        fs_type = ""
        for part in psutil.disk_partitions(all=True):
            mountpoint = os.path.normpath(part.mountpoint)
            try:
                covers = os.path.commonpath([my_path, mountpoint]) == mountpoint
            except ValueError:
                # Mixed absolute/relative paths or different drives.
                continue
            # Later mount-table entries shadow earlier ones on the same
            # mountpoint, so ties go to the later entry.
            if covers and len(mountpoint) >= len(best_mountpoint):
                best_mountpoint = mountpoint
                fs_type = part.fstype
        return fs_type

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_read_write_tensors(self):
        """Round-trip two CUDA tensors through a GDS file and compare them."""
        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
        src1 = torch.randn(1024, device="cuda")
        src2 = torch.randn(2, 1024, device="cuda")
        with contextlib.ExitStack() as registered:
            # Deregister the source buffers even if I/O or a comparison fails.
            for src in (src1, src2):
                torch.cuda.gds._gds_register_buffer(src.untyped_storage())
                registered.callback(
                    torch.cuda.gds._gds_deregister_buffer, src.untyped_storage()
                )
            dest1 = torch.empty(1024, device="cuda")
            dest2 = torch.empty(2, 1024, device="cuda")
            with TemporaryFileName() as f:
                gds_file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
                gds_file.save_storage(src1.untyped_storage(), offset=0)
                gds_file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
                gds_file.load_storage(dest1.untyped_storage(), offset=0)
                gds_file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
            self.assertEqual(src1, dest1)
            self.assertEqual(src2, dest2)


# Expand the device-generic and parametrized test templates into concrete
# test classes at import time; the device-type variants are written into
# this module's namespace via globals().
instantiate_device_type_tests(TestCudaOptims, globals())
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_parametrized_tests(TestCuda)


if __name__ == "__main__":
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.