from copy import deepcopy
from itertools import product
from random import randint

from torch import inf, nan
from torch.cuda._memory_viz import (
from torch.testing._internal.autocast_test_lists import AutocastTestLists
from torch.testing._internal.common_cuda import (
    _get_torch_cuda_version,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    NO_MULTIPROCESSING_SPAWN,
    skipCUDAMemoryLeakCheckIf,
    skipCUDANonDefaultStreamIf,
)
from torch.utils.checkpoint import checkpoint_sequential
from torch.utils.viz._cycles import observe_tensor_cycles

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake8 warnings.
load_tests = load_tests

if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)

try:
    import torchvision.models  # noqa: F401
    from torchvision.models import resnet18  # noqa: F401

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

TEST_CUDAMALLOCASYNC = TEST_CUDA and (
    torch.cuda.get_allocator_backend() == "cudaMallocAsync"
)
TEST_LARGE_TENSOR = TEST_CUDA
TEST_MEDIUM_TENSOR = TEST_CUDA
TEST_PYNVML = not torch.cuda._HAS_PYNVML
if TEST_CUDA:
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9
    TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9
    TEST_BF16 = torch.cuda.is_bf16_supported()
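
# The gates above size-check the visible device: "large tensor" tests require
# at least ~12 GB of GPU memory, "medium tensor" tests ~6 GB, and bf16 tests
# only run where the device reports bfloat16 support.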

@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCuda(TestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    FIFTY_MIL_CYCLES = 50000000

    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    @property
    def expandable_segments(self):
        return EXPANDABLE_SEGMENTS

    def test_pinned_memory_with_cudaregister(self):
        torch.cuda.memory._set_allocator_settings(
            "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
        )
        t = torch.ones(20)  # any small host tensor works here
        self.assertFalse(t.is_pinned())
        try:
            pinned_t = torch.ones(1 << 21).pin_memory()
            self.assertTrue(pinned_t.is_pinned())
            pinned_t = torch.ones(1 << 24).pin_memory()
            self.assertTrue(pinned_t.is_pinned())
        except RuntimeError as e:
            pass

    def test_pinned_memory_with_cudaregister_multithread(self):
        threads = [
            threading.Thread(target=self.test_pinned_memory_with_cudaregister)
            for t in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def test_pinned_memory_empty_cache(self):
        for alloc_settings in (True, False):
            torch.cuda.memory._set_allocator_settings(
                f"pinned_use_cuda_host_register:{alloc_settings}"
            )
            try:
                t = torch.ones(1024 * 1024, pin_memory=True)
                self.assertTrue(t.is_pinned())
                del t
                torch._C._host_emptyCache()
            except RuntimeError as e:
                pass
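
    # The test below goes through the raw CUDA runtime bindings:
    # cudaHostRegister page-locks an already-allocated host tensor (is_pinned()
    # becomes True) and cudaHostUnregister undoes it; both return 0 on success.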
    def test_cudart_register(self):
        t = torch.ones(20)  # any small host tensor works here
        self.assertFalse(t.is_pinned())
        cudart = torch.cuda.cudart()
        r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
        self.assertEqual(r, 0)
        self.assertTrue(t.is_pinned())
        r = cudart.cudaHostUnregister(t.data_ptr())
        self.assertEqual(r, 0)
        self.assertFalse(t.is_pinned())

    def test_memory_allocation(self):
        torch.cuda.empty_cache()
        prev = torch.cuda.memory_allocated()
        mem = torch.cuda.caching_allocator_alloc(size)
        self.assertGreater(torch.cuda.memory_allocated(), prev)
        torch.cuda.caching_allocator_delete(mem)
        self.assertEqual(torch.cuda.memory_allocated(), prev)

    def test_check_error(self):
        # Error number 0 must not raise.
        torch.cuda.check_error(0)

        with self.assertRaisesRegex(
            torch.cuda.CudaError, "out of memory|hipErrorOutOfMemory"
        ):
            torch.cuda.check_error(2)

    def test_cuda_get_device_name(self):
        current_device = torch.cuda.current_device()
        current_device_name = torch.cuda.get_device_name(current_device)
        device_name_None = torch.cuda.get_device_name(None)
        self.assertEqual(current_device_name, device_name_None)

        device_name_no_argument = torch.cuda.get_device_name()
        self.assertEqual(current_device_name, device_name_no_argument)

    def test_cuda_get_device_capability(self):
        current_device = torch.cuda.current_device()
        current_device_capability = torch.cuda.get_device_capability(current_device)
        device_capability_None = torch.cuda.get_device_capability(None)
        self.assertEqual(current_device_capability, device_capability_None)

        device_capability_no_argument = torch.cuda.get_device_capability()
        self.assertEqual(current_device_capability, device_capability_no_argument)
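
    # The OOM tests below match allocator-specific error text: cudaMallocAsync
    # reports "would exceed allowed memory", while the native caching allocator
    # reports "Tried to allocate ...".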
    def test_out_of_memory(self):
        tensor = torch.zeros(1024, device="cuda")

        oom_regex = (
            "would exceed allowed memory"
            if TEST_CUDAMALLOCASYNC
            else "Tried to allocate 800000000.00 GiB"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="cuda")

        with self.assertRaisesRegex(
            RuntimeError, "Tried to allocate more than 1EB memory"
        ):
            torch.empty(
                1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="cuda"
            )

        # Ensure the OOM error doesn't disturb subsequent kernels.
        tensor.fill_(1)
        self.assertTrue((tensor == 1).all())

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)"
    )
    def test_out_of_memory_retry(self):
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        oom_regex = (
            "would exceed allowed memory"
            if TEST_CUDAMALLOCASYNC
            else "Tried to allocate"
        )
        size = int(total_memory * 0.5)
        a = torch.empty(size, dtype=torch.int8, device="cuda")
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            b = torch.empty(size, dtype=torch.int8, device="cuda")
        del a
        b = torch.empty(size, dtype=torch.int8, device="cuda")
        del b
        # We used a lot of memory here; clean up so we don't affect other tests.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    def test_set_per_process_memory_fraction(self):
        # Test invalid fraction values.
        with self.assertRaisesRegex(TypeError, "Invalid type"):
            torch.cuda.set_per_process_memory_fraction(1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(-0.1)
        with self.assertRaisesRegex(ValueError, "Invalid fraction value"):
            torch.cuda.set_per_process_memory_fraction(2.0)

        tensor = torch.zeros(1024, device="cuda")
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        torch.cuda.set_per_process_memory_fraction(0.5, 0)

        # An allocation at 0.499 of total memory should succeed.
        application = int(total_memory * 0.499) - torch.cuda.max_memory_reserved()
        tmp_tensor = torch.empty(application, dtype=torch.int8, device="cuda")
        del tmp_tensor
        torch.cuda.empty_cache()

        # Allocating more than half of total memory should raise OOM.
        application = int(total_memory * 0.5)
        oom_regex = (
            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            torch.empty(application, dtype=torch.int8, device="cuda")

        # Ensure the OOM error doesn't disturb subsequent kernels.
        tensor.fill_(1)
        self.assertTrue((tensor == 1).all())

    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "uuid attribute not yet available")
    def test_uuid(self):
        uuid = torch.cuda.get_device_properties(0).uuid
        self.assertEqual(len(str(uuid)), 36)  # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
        self.assertEqual(len(uuid.bytes), 16)

    def test_copy_non_blocking(self):
        def _test_copy_non_blocking(a, b):
            event = torch.cuda.Event()
            a.copy_(b, non_blocking=True)
            event.record()
            event.synchronize()
            self.assertEqual(a, b)

        # 10 MB copies.
        x = torch.ones(10000000, dtype=torch.uint8).cuda()
        y = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        _test_copy_non_blocking(x, y)

        x = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        y = torch.ones(10000000, dtype=torch.uint8).cuda()
        _test_copy_non_blocking(x, y)

        # Test the case where the pinned data_ptr is not equal to the storage data_ptr.
        x_base = torch.zeros(10000000, dtype=torch.uint8).pin_memory()
        x = x_base[1:]
        self.assertTrue(x.is_pinned())
        self.assertTrue(x_base.is_pinned())
        self.assertNotEqual(x_base.data_ptr(), x.data_ptr())
        self.assertEqual(x_base.storage().data_ptr(), x.storage().data_ptr())
        y = torch.ones(10000000 - 1, dtype=torch.uint8).cuda()
        _test_copy_non_blocking(x, y)
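
    # Both copy tests rely on stream ordering: the non-blocking copy is queued
    # on the current stream, and a CUDA event (or a later dependent op) is used
    # to wait for completion before values are compared.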
    def test_copy_non_blocking_type_conversion(self):
        a = torch.ones(1, device="cuda")
        b = torch.zeros(1, device="cpu", pin_memory=True)
        c = torch.empty(1, device="cuda", dtype=torch.long)
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        b.copy_(a, non_blocking=True)
        c.copy_(b, non_blocking=True)
        self.assertEqual(a, c, exact_dtype=False)
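
    # The test below uses torch.cuda._sleep to keep the stream busy for roughly
    # 100 ms: if .to() is truly non-blocking, stream.query() must still report
    # pending work (False) at the moment the call returns.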
    def test_to_non_blocking(self):
        stream = torch.cuda.current_stream()

        def _test_to_non_blocking(a, non_blocking, dst):
            torch.cuda.synchronize()
            # Pushes a ~100 ms spin to the stream so that, if the copy is
            # non-blocking, the stream is almost surely active when we query().
            torch.cuda._sleep(int(100 * get_cycles_per_ms()))
            b = a.to(device=dst, non_blocking=non_blocking)
            self.assertEqual(stream.query(), not non_blocking)
            stream.synchronize()
            self.assertEqual(a, b)
            self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu"))

        for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)):
            # Creates the source on the opposite device from the destination.
            src = torch.randn(
                1000000,
                device="cuda" if dst == "cpu" else "cpu",
                pin_memory=True if dst == "cuda" else False,
            )
            _test_to_non_blocking(src, try_non_blocking, dst)

    def test_to_cpu_blocking_by_default(self):
        src = torch.randn(1000000, device="cuda")
        torch.cuda.synchronize()
        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
        dst = src.to(device="cpu")
        self.assertEqual(torch.cuda.current_stream().query(), True)
        self.assertEqual(src, dst)
        self.assertFalse(dst.is_pinned())

    def test_serialization_array_with_storage(self):
        x = torch.randn(5, 5).cuda()
        y = torch.IntTensor(2, 5).fill_(0).cuda()
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, atol=0, rtol=0)
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
        self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
        self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
        self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
        self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
        self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
        q_copy[1].fill_(10)
        self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
    )
    @unittest.skipIf(
        _get_torch_cuda_version() >= (12, 2),
        "skipped as explicit workspace allocation is removed",
    )
    def test_cublas_workspace_explicit_allocation(self):
        a = torch.randn(7, 7, device="cuda", requires_grad=False)
        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
        # A different size (32 MiB) is expected on Hopper (capability 9.0) GPUs.
        if torch.cuda.get_device_capability() == (9, 0):
            default_workspace_size = 4096 * 8 * 1024

        def check_workspace_size(inp):
            torch._C._cuda_clearCublasWorkspaces()
            start = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            with torch.no_grad():
                torch.matmul(inp, inp)
            finish = torch.cuda.memory_stats()["active_bytes.all.allocated"]
            return finish - start

        # Check the default workspace size.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)

        # An invalid config should fall back to the default.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = "-1"
        self.assertTrue(abs(check_workspace_size(a) - default_workspace_size) < 524288)

        # Check a valid config.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":128:8:64:16:32:32"
        self.assertTrue(abs(check_workspace_size(a) - (3072 * 1024)) < 524288)

        torch._C._cuda_clearCublasWorkspaces()
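
    # CUBLAS_WORKSPACE_CONFIG is a sequence of ":<size>:<count>" pairs with
    # sizes in KiB, so ":4096:2:16:8" is 4096*2*1024 + 16*8*1024 bytes and
    # ":128:8:64:16:32:32" is 3072*1024 bytes, matching the sizes checked above.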
    def test_cublas_allow_tf32_get_set(self):
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
        )
        if skip_tf32_cublas:
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
            return

        orig = torch.backends.cuda.matmul.allow_tf32
        self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
        torch.backends.cuda.matmul.allow_tf32 = not orig
        self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
        torch.backends.cuda.matmul.allow_tf32 = orig

    def test_float32_matmul_precision_get_set(self):
        orig = torch.get_float32_matmul_precision()
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
        )
        # This is really just checking that the environment variable is
        # respected during testing.
        if not skip_tf32_cublas:
            self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
            self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        else:
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
        for p in ("medium", "high"):
            torch.set_float32_matmul_precision(p)
            self.assertEqual(torch.get_float32_matmul_precision(), p)
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
        torch.set_float32_matmul_precision("highest")
        self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
        torch.set_float32_matmul_precision(orig)

    def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        self.assertEqual(
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig
        )
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = not orig
        self.assertEqual(
            torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig
        )
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig

    def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
        orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        self.assertEqual(
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig
        )
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
        self.assertEqual(
            torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig
        )
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig

    def test_cudnn_allow_tf32_get_set(self):
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=False
        ):
            self.assertFalse(torch.backends.cudnn.allow_tf32)
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
        ):
            self.assertTrue(torch.backends.cudnn.allow_tf32)

    def test_type_conversions(self):
        x = torch.randn(5, 5)
        self.assertIsInstance(x.float(), torch.FloatTensor)
        self.assertIsInstance(x.cuda().double(), torch.cuda.DoubleTensor)
        self.assertIsInstance(x.cuda().float(), torch.cuda.FloatTensor)
        self.assertIsInstance(x.cuda().float().cpu(), torch.FloatTensor)
        self.assertIsInstance(x.cuda().float().cpu().int(), torch.IntTensor)

        y = x.storage()  # the storage of a float tensor
        self.assertIsInstance(y.float(), torch.FloatStorage)
        self.assertIsInstance(y.cuda().double(), torch.cuda.DoubleStorage)
        self.assertIsInstance(y.cuda().float(), torch.cuda.FloatStorage)
        self.assertIsInstance(y.cuda().float().cpu(), torch.FloatStorage)
        self.assertIsInstance(y.cuda().float().cpu().int(), torch.IntStorage)

    @unittest.skip("was disabled due to not enough memory, but actually it always fails")
    def test_arithmetic_large_tensor(self):
        x = torch.empty(2**30, device="cuda")

        x.fill_(1)
        self.assertEqual(x.sum(), 2**30)

        x += 1
        self.assertEqual(x.sum(), 2**31)

        x.fill_(1)
        x -= 0.5
        self.assertEqual(x.sum(), 2**29)

        x.fill_(1)
        x *= 2
        self.assertEqual(x.sum(), 2**31)

        x.fill_(1)
        x /= 2
        self.assertEqual(x.sum(), 2**29)

    def test_gather_bool(self):
        t = torch.tensor([[False, True], [True, True]], device="cuda")
        self.assertEqual(
            torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]], device="cuda")),
            torch.tensor([[False, False], [True, True]], device="cuda"),
        )

    def test_torch_manual_seed_seeds_cuda_devices(self):
        with freeze_rng_state():
            x = torch.zeros(4, 4).float().cuda()
            torch.manual_seed(2)
            self.assertEqual(torch.cuda.initial_seed(), 2)
            x.uniform_()
            torch.manual_seed(2)
            y = x.clone().uniform_()
            self.assertEqual(x, y)
            self.assertEqual(torch.cuda.initial_seed(), 2)

    def test_manual_seed(self):
        with freeze_rng_state():
            x = torch.zeros(4, 4).float().cuda()
            torch.cuda.manual_seed(2)
            self.assertEqual(torch.cuda.initial_seed(), 2)
            x.uniform_()
            a = torch.bernoulli(torch.full_like(x, 0.5))
            torch.cuda.manual_seed(2)
            y = x.clone().uniform_()
            b = torch.bernoulli(torch.full_like(x, 0.5))
            self.assertEqual(x, y)
            self.assertEqual(a, b)
            self.assertEqual(torch.cuda.initial_seed(), 2)

    def test_specify_improper_device_name(self):
        fname = "tempfile.pt"
        try:
            with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
                torch.save(
                    [torch.nn.Parameter(torch.randn(10, 10))],
                    fname,
                    _use_new_zipfile_serialization=True,
                )
                torch.load(fname, "cuda0")
        finally:
            if os.path.exists(fname):
                os.remove(fname)

    def test_get_device_index(self):
        from torch.cuda._utils import _get_device_index

        with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
            _get_device_index("cuda0", optional=True)

        with self.assertRaisesRegex(ValueError, "Expected a cuda device"):
            cpu_device = torch.device("cpu")
            _get_device_index(cpu_device, optional=True)

    def test_serialization_array_with_empty(self):
        x = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)
        for original, copy in zip(x, x_copy):
            self.assertEqual(copy, original)
            self.assertIs(type(copy), type(original))
            self.assertEqual(copy.get_device(), original.get_device())

    @skipCUDANonDefaultStreamIf(True)
    def test_streams(self):
        default_stream = torch.cuda.current_stream()
        user_stream = torch.cuda.Stream()
        self.assertEqual(torch.cuda.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        self.assertEqual(default_stream.cuda_stream, 0)
        self.assertNotEqual(user_stream.cuda_stream, 0)
        with torch.cuda.stream(user_stream):
            self.assertEqual(torch.cuda.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        tensor1 = torch.ByteTensor(5).pin_memory()
        tensor2 = tensor1.cuda(non_blocking=True) + 1
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    def test_stream_event_repr(self):
        s = torch.cuda.current_stream()
        self.assertTrue("torch.cuda.Stream" in s.__repr__())
        e = torch.cuda.Event()
        self.assertTrue("torch.cuda.Event" in e.__repr__())
        s.record_event(e)
        self.assertTrue("torch.cuda.Event" in e.__repr__())

    def test_events(self):
        stream = torch.cuda.current_stream()
        event = torch.cuda.Event(enable_timing=True)
        self.assertTrue(event.query())
        start_event = torch.cuda.Event(enable_timing=True)
        stream.record_event(start_event)
        torch.cuda._sleep(int(50 * get_cycles_per_ms()))
        stream.record_event(event)
        self.assertFalse(event.query())
        event.synchronize()
        self.assertTrue(event.query())
        self.assertGreater(start_event.elapsed_time(event), 0)
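
    # The next test checks interop between the generic torch.Stream/torch.Event
    # API and torch.cuda.Stream: a CUDA stream constructed from the generic
    # stream's (stream_id, device_index, device_type) must alias the same stream.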
    def test_generic_stream_event(self):
        stream = torch.Stream("cuda")
        self.assertEqual(stream.device_index, torch.cuda.current_device())
        cuda_stream = torch.cuda.Stream(
            stream_id=stream.stream_id,
            device_index=stream.device_index,
            device_type=stream.device_type,
        )
        self.assertEqual(stream.stream_id, cuda_stream.stream_id)
        self.assertNotEqual(stream.stream_id, torch.cuda.current_stream().stream_id)

        event1 = torch.Event("cuda", enable_timing=True)
        event2 = torch.Event("cuda", enable_timing=True)
        self.assertEqual(event1.event_id, 0)
        a = torch.randn(1000)
        b = torch.randn(1000)
        with torch.cuda.stream(cuda_stream):
            a_cuda = a.to("cuda", non_blocking=True)
            b_cuda = b.to("cuda", non_blocking=True)
            self.assertEqual(stream.stream_id, torch.cuda.current_stream().stream_id)
        event1.record(stream)
        event1.synchronize()
        self.assertTrue(event1.query())
        c_cuda = a_cuda + b_cuda
        event2.record()
        event2.synchronize()
        self.assertTrue(event2.query())
        self.assertNotEqual(event1.event_id, event2.event_id)
        self.assertEqual(c_cuda.cpu(), a + b)
        self.assertTrue(event1.elapsed_time(event2) > 0)

    def test_record_stream(self):
        cycles_per_ms = get_cycles_per_ms()

        t = torch.FloatTensor([1, 2, 3, 4]).pin_memory()
        result = torch.cuda.FloatTensor(t.size())
        stream = torch.cuda.Stream()
        ptr = [None]

        # Performs the CPU->GPU copy in a background stream.
        with torch.cuda.stream(stream):
            tmp = t.cuda(non_blocking=True)
            ptr[0] = tmp.data_ptr()
        torch.cuda.current_stream().wait_stream(stream)
        tmp.record_stream(torch.cuda.current_stream())
        torch.cuda._sleep(int(50 * cycles_per_ms))  # delay the copy
        result.copy_(tmp)
        del tmp

        with torch.cuda.stream(stream):
            tmp2 = torch.cuda.FloatTensor(t.size())
            self.assertNotEqual(
                tmp2.data_ptr(), ptr[0], msg="allocation re-used too soon"
            )

        self.assertEqual(result.tolist(), [1, 2, 3, 4])

        if not TEST_CUDAMALLOCASYNC:
            # With the native allocator, tmp's side-stream-tagged block should
            # be re-used on that side stream once the delayed copy has finished.
            torch.cuda.current_stream().synchronize()
            with torch.cuda.stream(stream):
                tmp3 = torch.cuda.FloatTensor(t.size())
                self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used")
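
    # Tensor.record_stream(s) tags the tensor's block as in use by stream s, so
    # the caching allocator won't hand that block out again until s's pending
    # work finishes; the test above checks both "not re-used too soon" and the
    # eventual re-use (the latter only with the native allocator).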
    def test_record_stream_on_shifted_view(self):
        stream_alloc = torch.cuda.Stream()
        with torch.cuda.stream(stream_alloc):
            base = torch.cuda.FloatTensor([10, 10])

        # Record another stream on a shifted view tensor.
        view = base[5:]
        assert view.storage_offset() > 0

        stream_record = torch.cuda.Stream()
        with torch.cuda.stream(stream_record):
            torch.cuda._sleep(int(50 * get_cycles_per_ms()))

        view.record_stream(stream_record)

        # Delete those tensors to make the block reusable.
        data_ptr = base.data_ptr()
        del base, view

        # A new tensor should not be allocated to the block above.
        stream_alloc.synchronize()

        with torch.cuda.stream(stream_alloc):
            try_realloc = torch.cuda.FloatTensor([10, 10])

        self.assertNotEqual(try_realloc.data_ptr(), data_ptr)

    def test_noncontiguous_pinned_memory(self):
        x = torch.arange(0, 10).view((2, 5))
        self.assertEqual(x.t(), x.t().pin_memory())

    def test_caching_pinned_memory(self):
        cycles_per_ms = get_cycles_per_ms()

        # Check that allocations are re-used after deletion.
        t = torch.FloatTensor([1]).pin_memory()
        ptr = t.data_ptr()
        del t
        t = torch.FloatTensor([1]).pin_memory()
        self.assertEqual(t.data_ptr(), ptr, msg="allocation not reused")

        # Check that the allocation is not re-used while it's in use by a copy.
        gpu_tensor = torch.cuda.FloatTensor([0])
        torch.cuda._sleep(int(1000 * cycles_per_ms))  # delay the copy by ~1 s
        gpu_tensor.copy_(t, non_blocking=True)
        del t
        t = torch.FloatTensor([1]).pin_memory()
        self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon")
        self.assertEqual(list(gpu_tensor), [1])

    def test_caching_allocator_record_stream_oom(self):
        """allocations delayed by a record_stream call should still be freed on
        an out-of-memory in cuda_malloc_retry. see issue #19219"""
        stream = torch.cuda.Stream()

        with torch.cuda.stream(stream):
            y = torch.zeros(40 * 1024 * 1024, device="cuda")

        for _ in range(100):
            x = torch.empty(40 * 1024 * 1024, device="cuda")
            with torch.cuda.stream(stream):
                y += x
            # Delays re-use of `x` in the allocator until all operations on
            # `stream` finish.
            x.record_stream(stream)

        torch.cuda.empty_cache()

    def test_reduction_gpu_memory_accessing(self):
        x = torch.ones(512, 8, dtype=torch.float32, device="cuda")
        torch.sum(x, 0)

    def test_sum_fp16(self):
        x = torch.zeros(10, device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum(), 0)

        x = torch.ones(65504, device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum(), 65504)
        self.assertEqual(x.sum(dtype=torch.float32), 65504)

        x = torch.ones(65536, device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum(dtype=torch.float32), 65536)

        a = torch.zeros(1203611).bernoulli_(0.0005)
        x = a.to(device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum().item(), a.sum().item())

        a = torch.zeros(100, 121, 80).bernoulli_(0.0005)
        x = a.to(device="cuda", dtype=torch.float16)
        self.assertEqual(x.sum((0, 2)).float().cpu(), a.sum((0, 2)))

    def test_mean_fp16(self):
        x = torch.ones(65536, device="cuda", dtype=torch.float16)
        self.assertEqual(x.mean(), 1)

        x = torch.ones(65536, device="cuda", dtype=torch.float16)
        self.assertEqual(x.mean(dtype=torch.float32), 1)

    def test_prod_large(self):
        # Tests the global reduction path with a non-zero identity element.
        x = torch.ones(240000, device="cuda", dtype=torch.float32)
        self.assertEqual(x.prod(), 1)

        # Test complex dtypes; note 240k is divisible by 4, so i**240000 == 1.
        for dtype in [torch.cfloat, torch.cdouble]:
            x = torch.ones(240000, device="cuda", dtype=dtype) * (0 + 1j)
            self.assertEqual(x.prod(), 1)

    def test_multinomial_ext(self):
        # Corner cases with very small, non-zero probabilities.
        freqs = torch.cuda.FloatTensor(
            [
                0.027680952101945877,
                0.033176131546497345,
                0.046052902936935425,
                0.049702685326337814,
                0.027557924389839172,
                0.018125897273421288,
                0.011851548217236996,
                0.010252203792333603,
                0.007422595750540495,
                0.005372154992073774,
                0.0036087757907807827,
                0.0035267581697553396,
                0.0018864056328311563,
                0.0024605290964245796,
                0.0022964938543736935,
                0.0018453967059031129,
                0.0010662291897460818,
                0.0009842115687206388,
                0.00045109697384759784,
                0.0007791675161570311,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00012302644609007984,
                0.00012302644609007984,
                4.100881778867915e-05,
            ]
        )

        torch.cuda.manual_seed(11042)
        sample = torch.multinomial(freqs, 1000, True)
        self.assertNotEqual(freqs[sample].min(), 0)

        p = torch.zeros(3421, 2, device="cuda", dtype=torch.float)
        p[:, 1] = 1
        torch.cuda.manual_seed(5214)
        r = torch.multinomial(p, 1)
        self.assertNotEqual(r.min().item(), 0)

        # Sampling a huge number of tiny probabilities must never pick a
        # zero-probability entry.
        torch.cuda.manual_seed(33)
        probs = torch.randn(1000000, device="cuda").clamp(min=0) * 3e-5
        samples = probs.multinomial(1000000, replacement=True)
        self.assertGreater(probs[samples].min().item(), 0)

    def _spawn_test_multinomial_invalid_probs_cuda(self, probs):
        import subprocess

        try:
            p = subprocess.Popen(
                [
                    sys.executable,
                    "-c",
                    f"""\
import sys
import torch
from torch import inf, nan
try:
    with torch.random.fork_rng(devices=[0]):
        torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True)
        torch.cuda.synchronize()
    sys.exit(-1) # Should not be reached
except RuntimeError as e:
    sys.exit(-2)
""",
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
            out, err = p.communicate(timeout=10)
        except subprocess.TimeoutExpired as e:
            p.kill()
            out, err = p.communicate()
        expected_messages = [
            "device-side assert triggered",  # CUDA
            "HSA_STATUS_ERROR_EXCEPTION",  # ROCm
            "Device-side assertion",  # ROCm
        ]
        self.assertTrue(any(msg in out or msg in err for msg in expected_messages))

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
        don't support multiprocessing with spawn start method",
    )
    def test_multinomial_invalid_probs_cuda(self):
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -1.0, 1.0])
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, inf, 1.0])
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, -inf, 1.0])
        self._spawn_test_multinomial_invalid_probs_cuda([1.0, 1.0, nan])

    @staticmethod
    def _mute_init():
        os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())

    def _spawn_method(self, method, arg):
        ctx = torch.multiprocessing.get_context("spawn")
        with ctx.Pool(1, initializer=self._mute_init) as pool:
            errors = pool.map(method, [arg])
            for e in errors:
                if "device-side assert triggered" not in str(e):
                    self.fail(e)

    @staticmethod
    def _test_index_bounds_cuda(idx):
        x = torch.arange(10, device="cuda")
        try:
            y = x[torch.tensor([idx])]
            return f"x[torch.tensor([{idx}])]={y}"
        except RuntimeError as err:
            return err

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
        don't support multiprocessing with spawn start method",
    )
    def test_index_out_of_bounds_exception_cuda(self):
        test_method = TestCuda._test_index_bounds_cuda
        # In-bounds access works fine.
        self.assertEqual(
            test_method(1), "x[torch.tensor([1])]=tensor([1], device='cuda:0')"
        )
        # Out-of-bounds indexing triggers a device-side assert.
        self._spawn_method(test_method, 11)

    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    def test_huge_index(self):
        src = torch.empty(15000000, 45, device="cuda", dtype=torch.long).random_(
        )
        idx = torch.randperm(src.shape[0], device="cuda")
        res = src[idx]
        res_cpu = src.cpu()[idx.cpu()]
        self.assertEqual(res.cpu(), res_cpu)

    def test_randint_randomness_for_large_range(self) -> None:
        # For large ranges, randint generates each random value with two uint32s.
        high = 6_000_000_000

        def run(dev: torch.device) -> int:
            gen = torch.Generator(device=dev)
            t1 = torch.randint(
                0, high, [size], device=dev, generator=gen, dtype=torch.int64
            )
            t2 = torch.randint(
                0, high, [size], device=dev, generator=gen, dtype=torch.int64
            )
            return torch.stack([t1, t2]).unique().shape[0]

        # Use the CPU as a reference; the unique counts should be similar.
        assert abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000

    @parametrize("dtype", [torch.float32, torch.double])
    def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
        # Random states must not be reused between consecutive calls.

        def run(func, dev: torch.device, dtype: torch.dtype) -> int:
            gen = torch.Generator(device=dev)
            t1 = func((size,), device=dev, generator=gen, dtype=dtype)
            t2 = func((size,), device=dev, generator=gen, dtype=dtype)
            return torch.stack([t1, t2]).unique().shape[0]

        # Use the CPU as a reference; the deviation should be small.
        for func in [torch.rand, torch.randn]:
            deviation = abs(
                run(func, torch.device("cuda"), dtype)
                - run(func, torch.device("cpu"), dtype)
            )
            assert deviation < 50_000, deviation

    def test_min_max_inits(self):
        # Reductions at extreme values must initialize their indices correctly.
        x = torch.cuda.ByteTensor([0])
        y = torch.cuda.ByteTensor([255])
        expected = torch.cuda.LongTensor([0])[0]

        _, v = x.max(dim=0)
        self.assertEqual(v, expected)

        _, v = y.min(dim=0)
        self.assertEqual(v, expected)

    def test_nvtx(self):
        torch.cuda.nvtx.range_push("foo")
        torch.cuda.nvtx.mark("bar")
        torch.cuda.nvtx.range_pop()
        range_handle = torch.cuda.nvtx.range_start("range_start")
        torch.cuda.nvtx.range_end(range_handle)

    def test_bincount_ext(self):
        input_size = (100000,)
        w = torch.randn(input_size, dtype=torch.double, device="cuda")
        w_cpu = w.cpu()

        # Small number of bins: exercises the shared-memory implementation.
        t = torch.randint(50, input_size, dtype=torch.int8, device="cuda")
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        # Large number of bins: exercises the global-memory implementation.
        t = torch.randint(50000, input_size, dtype=torch.int64, device="cuda")
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.zeros([10], dtype=torch.int32, device="cuda")
        counted = t.bincount(minlength=65536)
        self.assertEqual(torch.sum(counted), 10)

    def test_tiny_half_norm_(self):
        a = torch.arange(25).cuda().float()
        a /= 100000000
        b = a.half()
        self.assertGreater(b.norm().item(), 0)

    def test_norm_type_conversion(self):
        a = torch.ones(65536).cuda().half()
        self.assertEqual(a.norm(p=0, dtype=torch.float32), 65536)

    def test_cuda_memory_leak_detection_propagates_errors(self):
        with self.assertRaisesRegex(
            RuntimeError, r"The size of tensor a \(3\) must match"
        ):
            with self.assertLeaksNoCudaTensors():
                x = torch.randn(3, 1, device="cuda")
                y = torch.randn(2, 1, device="cuda")
                z = x + y

    @unittest.skipIf(not TEST_MEDIUM_TENSOR, "not enough memory")
    def test_cuda_kernel_loop_overflow(self):
        x = torch.randn(1, 1, 1, 2**30 + 1, dtype=torch.float16, device="cuda")
        expected = x[0, 0, 0, 2**30]
        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
        torch.cuda.synchronize()
        self.assertEqual(y[0, 0, 0, 2**30], expected)

    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    def test_cuda_kernel_loop_overflow_large(self):
        # Make sure input.numel() > INT_MAX is handled.
        x = torch.randn(1, 1, 1, 2**31, dtype=torch.float16, device="cuda")
        with self.assertRaisesRegex(RuntimeError, "integer out of range"):
            y = torch.nn.functional.avg_pool2d(x, kernel_size=1)

        # numel() == INT_MAX (2**31 - 1) is still fine.
        x = torch.randn(1, 1, 1, 2**31 - 1, dtype=torch.float16, device="cuda")
        expected = x[0, 0, 0, 2**31 - 2]
        y = torch.nn.functional.avg_pool2d(x, kernel_size=1)
        torch.cuda.synchronize()
        self.assertEqual(y[0, 0, 0, 2**31 - 2], expected)

    def _make_multiply_in_stream(self):
        class MultiplyInStream(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, val):
                ctx.val = val
                ctx.stream = torch.cuda.current_stream()
                return x * val

            @staticmethod
            def backward(ctx, grad):
                self.assertEqual(torch.cuda.current_stream(), ctx.stream)
                # Delays the operation in the background stream.
                torch.cuda._sleep(1000 * 5000)
                return grad * ctx.val, None

        return MultiplyInStream
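
    # MultiplyInStream asserts inside backward() that autograd runs the backward
    # op on the same stream its forward() ran on; the _sleep spin makes missing
    # synchronization observable in the streaming-backwards tests below.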
    @skipCUDANonDefaultStreamIf(True)
    def test_streaming_backwards_sync(self):
        default_stream = torch.cuda.current_stream()
        stream = torch.cuda.Stream()

        MultiplyInStream = self._make_multiply_in_stream()

        # Tests using grads outside the backward() stream context.
        # See "Stream semantics of backward passes" in the CUDA notes.
        x = torch.randn(5, 5, device="cuda", requires_grad=True)
        with torch.cuda.stream(stream):
            stream.wait_stream(default_stream)
            output = MultiplyInStream.apply(x, 2)
            output.sum().backward()
        # Sync is needed before using x.grad on the default stream.
        default_stream.wait_stream(stream)
        self.assertEqual(x.grad, torch.ones_like(x) * 2)
        self.assertEqual(torch.cuda.current_stream(), default_stream)

        # Tests that using grads in the same stream context as backward()
        # is safe regardless of which streams the bwd ops ran on.
        bwd_ambient_stream = torch.cuda.Stream()
        x = torch.randn(5, 5, device="cuda", requires_grad=True)
        with torch.cuda.stream(stream):
            stream.wait_stream(default_stream)
            output = MultiplyInStream.apply(x, 3)
        with torch.cuda.stream(bwd_ambient_stream):
            bwd_ambient_stream.wait_stream(stream)
            output.sum().backward()
            self.assertEqual(x.grad, torch.ones_like(x) * 3)
            self.assertEqual(torch.cuda.current_stream(), bwd_ambient_stream)

    @skipIfRocm(msg="flakey on ROCm https://github.com/pytorch/pytorch/issues/53190")
    def test_streaming_backwards_multiple_streams(self):
        MultiplyInStream = self._make_multiply_in_stream()

        class StreamModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.event = torch.cuda.Event()
                self.stream0 = torch.cuda.Stream()
                self.stream1 = torch.cuda.Stream()

            def forward(self, x, x_first_use_on_ambient):
                if x_first_use_on_ambient:
                    x0 = x.clone()
                self.stream0.wait_stream(torch.cuda.current_stream())
                self.stream1.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(self.stream0):
                    if not x_first_use_on_ambient:
                        x0 = x.clone()
                    y0 = MultiplyInStream.apply(x0, 2)
                    self.event.record(stream=torch.cuda.current_stream())

                with torch.cuda.stream(self.stream1):
                    y1 = MultiplyInStream.apply(x, 3)
                    self.stream1.wait_event(self.event)
                    return y0 + y1

        stream = torch.cuda.Stream()

        for x_first_use_on_ambient in (True, False):
            for out_of_place, iters in ((True, 1), (False, 1), (False, 5)):
                with torch.cuda.stream(stream):
                    x = torch.randn(5, 5, device="cuda", requires_grad=True)
                    model = StreamModel().cuda()
                    x.register_hook(
                        lambda grad: self.assertEqual(
                            torch.cuda.current_stream(),
                            stream if x_first_use_on_ambient else model.stream0,
                        )
                    )
                    for p in model.parameters():
                        self.assertTrue(p.grad is None)
                    for i in range(iters):
                        loss = model(x, x_first_use_on_ambient).sum()
                        if out_of_place:
                            x_grad = torch.autograd.grad((loss,), (x,))[0]
                        else:
                            loss.backward()
                torch.cuda.current_stream().wait_stream(stream)

                if out_of_place:
                    self.assertEqual(x_grad, torch.ones_like(x) * 5 * iters)
                else:
                    self.assertEqual(x.grad, torch.ones_like(x) * 5 * iters)

    def test_streaming_backwards_sync_graph_root(self):
        # Guards against a race between filling the gradient on one stream and
        # the backward pass consuming it on another.
        fwd_bwd_op_stream = torch.cuda.Stream()
        bwd_ambient_stream = torch.cuda.Stream()
        # We need these streams to be different, otherwise the test is meaningless.
        self.assertTrue(fwd_bwd_op_stream != bwd_ambient_stream)

        a = torch.full((size,), 2.0, device="cuda", requires_grad=True)
        b = torch.full((size,), 3.0, device="cuda", requires_grad=True)

        for trial in range(5):
            torch.cuda.synchronize()
            a.grad = b.grad = None
            with torch.cuda.stream(fwd_bwd_op_stream):
                c = a * b

            with torch.cuda.stream(bwd_ambient_stream):
                torch.cuda.synchronize()
                # A long-running dummy kernel delays the fill of grad.
                torch.cuda._sleep(int(50 * get_cycles_per_ms()))
                # Fills grad on bwd_ambient_stream.
                grad = torch.full((size,), float(trial + 1), device="cuda")

                # Bwd ops still run on fwd_bwd_op_stream, so the following will
                # likely fail if bwd ops don't sync with bwd_ambient_stream
                # before consuming grad.
                torch.autograd.backward(tensors=c, grad_tensors=grad)

                torch.cuda.synchronize()
                with torch.no_grad():
                    self.assertEqual(a.grad, grad * b)
                    self.assertEqual(b.grad, grad * a)

    def test_streaming_backwards_callback(self):
        # Tests that autograd callbacks sync properly with respect to leaf streams.
        MultiplyInStream = self._make_multiply_in_stream()

        a = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)
        b = torch.full((size,), 1, device="cuda", dtype=torch.float, requires_grad=True)

        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()

        stash = []

        # Sets up a nontrivial structure of leaf streams.
        s0.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s0):
            c = MultiplyInStream.apply(a, 2)

        s1.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s1):
            d = MultiplyInStream.apply(b, 3)
            s1.wait_stream(s0)
            e = c * d

            def clone_leaf_grads():
                stash.append(a.grad.clone())
                stash.append(b.grad.clone())

            # Use a hook on e to install the callback.
            e.register_hook(
                lambda grad: torch.autograd.Variable._execution_engine.queue_callback(
                    clone_leaf_grads
                )
            )

        s2.wait_stream(s1)
        with torch.cuda.stream(s2):
            e.sum().backward()
            # The callback must observe fully-filled leaf grads (1 * 2 * 3 = 6).
            self.assertEqual(stash[0], torch.full_like(a, 6))
            self.assertEqual(stash[1], torch.full_like(a, 6))

    @unittest.skipIf(
        TEST_WITH_ROCM,
        "In ROCm, kernel asserts are disabled due to performance overhead",
    )
    def test_fixed_cuda_assert_async(self):
        with self.assertRaisesRegex(
            RuntimeError, "Boolean value of Tensor with no values is ambiguous"
        ):
            torch._assert_async(torch.tensor([], device="cuda"))
        with self.assertRaisesRegex(
            RuntimeError,
            "Boolean value of Tensor with more than one value is ambiguous",
        ):
            torch._assert_async(torch.tensor([0, 0], device="cuda"))

        torch._assert_async(torch.tensor(1, device="cuda"))
        torch._assert_async(torch.tensor(0.1, device="cuda"))
        torch._assert_async(torch.tensor(-0.1, device="cuda"))
        torch._assert_async(torch.tensor(True, device="cuda"))
        torch._assert_async(torch.tensor(0 + 0.1j, device="cuda"))

        fail_stmts = [
            "torch._assert_async(torch.tensor(0, device='cuda'))",
            "torch._assert_async(torch.tensor(0.0, device='cuda'))",
            "torch._assert_async(torch.tensor(False, device='cuda'))",
            "torch._assert_async(torch.tensor(0 + 0j, device='cuda'))",
        ]

        import subprocess

        for stmt in fail_stmts:
            with self.subTest(stmt=stmt):
                r = subprocess.call(
                    [
                        sys.executable,
                        "-c",
                        f"""\
import torch

{stmt}
torch.cuda.synchronize()
""",
                    ]
                )
                self.assertTrue(r != 0)
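
    # The three multi-thread tests below share one pattern: worker threads line
    # up on a threading.Barrier and then hammer the shared cuBLAS / cuDNN /
    # cuSPARSE handles from per-thread streams, so a handle race either crashes
    # or corrupts the deterministic expected results.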
    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "FAIL")
    def test_cublas_multiple_threads_same_device(self):
        weight = torch.ones((size, size), device="cuda")
        results = {}
        barrier = threading.Barrier(num_threads)

        def _worker(t):
            my_stream = torch.cuda.Stream()
            # Hard sync so we don't need to worry about creating and using
            # tensors across streams, or thread-local default streams.
            torch.cuda.synchronize()
            # Line up threads to increase the likelihood of a race.
            barrier.wait()
            with torch.cuda.stream(my_stream):
                for i in range(test_iters):
                    # If a race condition exists, the following line will crash.
                    results[t] = torch.mm(results[t], weight)
                    results[t].div_(float(size))
            torch.cuda.synchronize()

        for _ in range(trials):
            for t in range(num_threads):
                results[t] = torch.ones((size, size), device="cuda")

            threads = [
                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            for t in range(num_threads):
                self.assertEqual(results[t].sum().item(), size * size)

    @unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows (see issue 57401)")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_cudnn_multiple_threads_same_device(self):
        # The weight held by all threads is a (1, 1, 2, 2) kernel of ones.
        weight = torch.ones((1, 1, 2, 2), device="cuda")
        results = {}
        barrier = threading.Barrier(num_threads)

        with torch.backends.cudnn.flags(enabled=True):

            def _worker(t):
                my_stream = torch.cuda.Stream()
                torch.cuda.synchronize()
                barrier.wait()
                with torch.cuda.stream(my_stream):
                    for _ in range(test_iters):
                        # If a race condition exists, the following line will crash.
                        results[t] = torch.nn.functional.conv2d(
                            results[t], weight, padding=0
                        )
                        results[t].div_(4.0)
                torch.cuda.synchronize()

            for _ in range(trials):
                for t in range(num_threads):
                    results[t] = torch.ones((1, 1, 2048, 2048), device="cuda")

                threads = [
                    threading.Thread(target=_worker, args=(t,))
                    for t in range(num_threads)
                ]

                for thread in threads:
                    thread.start()
                for thread in threads:
                    thread.join()

                for t in range(num_threads):
                    self.assertEqual(
                        results[t].sum().item(),
                        (2048 - test_iters) * (2048 - test_iters),
                    )

    def test_cusparse_multiple_threads_same_device(self):
        def ones_sparse(size):
            a = torch.arange(size, device="cuda")
            indices = torch.cartesian_prod(a, a).t()
            values = torch.ones(size * size, device="cuda")
            return torch.sparse_coo_tensor(indices, values)

        weight = ones_sparse(size)
        results = {}
        barrier = threading.Barrier(num_threads)

        def _worker(t):
            my_stream = torch.cuda.Stream()
            torch.cuda.synchronize()
            barrier.wait()
            with torch.cuda.stream(my_stream):
                for i in range(test_iters):
                    # If a race condition exists, the following line will crash.
                    results[t] = weight.mm(results[t])
                    results[t].div_(float(size))
            torch.cuda.synchronize()

        for _ in range(trials):
            for t in range(num_threads):
                results[t] = torch.ones((size, size), device="cuda")

            threads = [
                threading.Thread(target=_worker, args=(t,)) for t in range(num_threads)
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            for t in range(num_threads):
                self.assertEqual(results[t].sum().item(), size * size)

    def _run_autocast_outofplace(
        self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None
    ):
        # Helper to cast args.
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}
        fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16
        self.assertFalse(torch.is_autocast_enabled())
        with torch.autocast("cuda", dtype=fast_dtype):
            self.assertTrue(torch.is_autocast_enabled())

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try the module-level (e.g. torch.*) variant:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )

            # Try the Tensor.* method variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for torch.{op} produced {output_method.dtype}, should produce {out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both variants were found, check that the outputs are identical.
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics against manually-casted inputs with autocast
            # disabled, which should be bitwise identical.
            output_to_compare = output if output is not None else output_method
            with torch.autocast("cuda", enabled=False):
                self.assertFalse(torch.is_autocast_enabled())

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_enabled())
        self.assertFalse(torch.is_autocast_enabled())

    def args_maybe_kwargs(self, op_with_args):
        if len(op_with_args) == 2:
            return op_with_args[0], op_with_args[1], {}
        else:
            return op_with_args[0], op_with_args[1], op_with_args[2]
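
    # _run_autocast_outofplace (above) runs each op under autocast through both
    # the module-level (torch.*) and Tensor-method variants, checks the output
    # dtype, and compares against a control computed with autocast disabled and
    # inputs manually cast to run_as_type, which should match bitwise.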
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op_with_args in self.autocast_lists.torch_fp16:
                skip_test = False
                op, args = op_with_args[0], op_with_args[1]
                if len(op_with_args) == 3:
                    skip_test = op_with_args[2]
                if not skip_test:
                    self._run_autocast_outofplace(op, args, torch.float16)

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_bf16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op_with_args in self.autocast_lists.torch_fp16:
                skip_test = False
                op, args = op_with_args[0], op_with_args[1]
                if len(op_with_args) == 3:
                    skip_test = op_with_args[2]
                should_error_from_cudnn = "cudnn" in op and (
                    "TORCH_CUDNN_V8_API_DISABLED" in os.environ
                    and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"])
                    or torch.cuda.get_device_capability() < (8, 0)
                )
                should_error_from_not_implemented = should_error_from_cudnn
                if not skip_test:
                    if should_error_from_not_implemented:
                        with self.assertRaises(
                            RuntimeError,
                            msg=str(op) + " should not be supported for bfloat16!",
                        ):
                            self._run_autocast_outofplace(op, args, torch.bfloat16)
                    else:
                        if torch.cuda.is_bf16_supported():
                            self._run_autocast_outofplace(op, args, torch.bfloat16)
                        else:
                            with self.assertRaisesRegex(
                                RuntimeError, "Device does not support bfloat16"
                            ):
                                self._run_autocast_outofplace(op, args, torch.bfloat16)

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, add_kwargs=maybe_kwargs
            )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_need_autocast_promote(self):
        for op, args in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args, torch.float32)

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_torch_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.nn_fp16:
                self._run_autocast_outofplace(
                    op, args, torch.float16, module=torch._C._nn
                )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_bf16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.nn_fp16:
                if torch.cuda.is_bf16_supported():
                    self._run_autocast_outofplace(
                        op, args, torch.bfloat16, module=torch._C._nn
                    )
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "Device does not support bfloat16"
                    ):
                        self._run_autocast_outofplace(
                            op, args, torch.bfloat16, module=torch._C._nn
                        )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_nn_fp32(self):
        for op, args in self.autocast_lists.nn_fp32:
            self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn)

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_linalg_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.linalg_fp16:
                self._run_autocast_outofplace(
                    op, args, torch.float16, module=torch._C._linalg
                )

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_fp16(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            for op, args in self.autocast_lists.methods_fp16:
                self._run_autocast_outofplace(op, args, torch.float16, module=None)

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_fp32(self):
        for op, args in self.autocast_lists.methods_fp32:
            self._run_autocast_outofplace(op, args, torch.float32, module=None)

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_methods_expect_builtin_promote(self):
        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args, torch.float32, module=None, out_type=out_type
            )

    def test_autocast_banned(self):
        with torch.autocast("cuda"):
            for op, args, module in self.autocast_lists.banned:
                with self.assertRaises(RuntimeError):
                    getattr(module, op)(*args)

    def test_autocast_ignored_types(self):
        with torch.autocast("cuda"):
            for ignore_type in (torch.double, torch.int32):
                a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
                b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
                c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0")

                # Tests if CastPolicy::fp16 ops ignore double and int.
                # Currently, no ops in this policy support integer inputs.
                if ignore_type is torch.double:
                    with self.assertRaises(RuntimeError):
                        torch.mm(a_ignore, c_16)
                    with torch.autocast("cuda", enabled=False):
                        type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
                    self.assertTrue(
                        torch.mm(a_ignore, b_ignore).dtype is type_no_autocast
                    )

                # Tests if CastPolicy::fp32 ops ignore double and int.
                with torch.autocast("cuda", enabled=False):
                    type_no_autocast = torch.pow(a_ignore, 2.0).dtype
                self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)

                # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int.
                with torch.autocast("cuda", enabled=False):
                    type_no_autocast = torch.sum(a_ignore).dtype
                self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)

                # Tests if CastPolicy::fp32_append_dtype ops ignore double and int.
                # norm does not support int tensors on CUDA.
                if ignore_type is torch.double:
                    with torch.autocast("cuda", enabled=False):
                        type_no_autocast = torch.norm(a_ignore).dtype
                    self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)

    def test_autocast_custom_enabled(self):
        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda")
            def forward(ctx, a, b):
                self.assertTrue(a.dtype is torch.float32)
                self.assertTrue(b.dtype is torch.float32)
                self.assertTrue(torch.is_autocast_enabled())
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                self.assertTrue(torch.is_autocast_enabled())
                a, b = ctx.saved_tensors
                a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad)
                self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype)
                return a_grad, b_grad

        mymm = MyMM.apply

        x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)
        y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True)

        dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,)
        for dtype in dtypes:
            with torch.cuda.amp.autocast(dtype=dtype):
                output = mymm(x, y)
                self.assertTrue(output.dtype is dtype)
                loss = output.sum()
            loss.backward()

    def test_autocast_custom_cast_inputs(self):
        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
            def forward(ctx, a, container, expect_type):
                b = container[1][0]
                self.assertTrue(a.dtype is expect_type)
                self.assertTrue(b.dtype is expect_type)
                self.assertFalse(torch.is_autocast_enabled())
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                self.assertFalse(torch.is_autocast_enabled())
                a, b = ctx.saved_tensors
                return grad.mm(b.t()), None, None

        mymm = MyMM.apply

        x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
        # Put the second input in a nested container. Its tensor won't receive
        # a gradient, because autograd can't hand gradients back to non-Tensor
        # forward arguments, so requires_grad is set False explicitly.
        y = (
            0,
            {
                0: torch.randn(
                    (8, 8), device="cuda", dtype=torch.float16, requires_grad=False
                )
            },
        )

        with torch.autocast("cuda"):
            output = mymm(x, y, torch.float32)
            self.assertTrue(output.dtype is torch.float32)
            loss = output.sum()
        loss.backward()

        # custom_fwd becomes a no-op when mymm runs outside an autocast region.
        output = mymm(x, y, torch.float16)
        self.assertTrue(output.dtype is torch.float16)
        loss = output.sum()
        loss.backward()

    def test_autocast_custom_deprecated_warning(self):
        with warnings.catch_warnings(record=True) as w:

            class MyMM(torch.autograd.Function):
                @staticmethod
                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
                def forward(ctx, x, y):
                    ctx.save_for_backward(x, y)
                    self.assertFalse(torch.is_autocast_enabled())
                    return x + y

                @staticmethod
                @torch.cuda.amp.custom_bwd
                def backward(ctx, grad):
                    _, _ = ctx.saved_tensors
                    self.assertFalse(torch.is_autocast_enabled())
                    return grad, grad

        self.assertRegex(
            str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated."
        )
        self.assertRegex(
            str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated."
        )

        mymm = MyMM.apply
        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)
        with torch.amp.autocast("cuda"):
            output = mymm(x, y)
            loss = output.sum()
        loss.backward()

    def test_autocast_cat_jit(self):
        class Model(torch.nn.Module):
            def forward(self):
                a = torch.randn(1)
                b = torch.randn(1)
                c = torch.cat((a, b), 0)
                d = torch.stack([c, c], 0)
                return d

        # The JIT here doesn't really matter; we just need to call cat/stack
        # through the boxed API.
        model = Model()
        model_jit_script = torch.jit.script(model)

        with torch.autocast("cuda", enabled=True):
            model()
            model_jit_script()

    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_autocast_rnn(self):
        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
            # seq, batch, features, hidden size
            clses = ("RNN", "GRU", "LSTM")
            T, B, F, H = 3, 4, 5, 6
            dtypes = (torch.float16, torch.float32)
            input_layouts = ("seq_first", "batch_first", "packed")

            for (
                cls,
                num_layers,
                bias,
                input_layout,
                bidirectional,
                try_nonpreflattened_weights,
                input_dtype,
                hidden_dtype,
                weight_dtype,
            ) in product(
                clses,
                (1, 2),
                (True, False),
                input_layouts,
                (True, False),
                (True, False),
                dtypes,
                dtypes,
                dtypes,
            ):
                if input_layout == "seq_first":
                    batch_first = False
                    x = torch.randn((T, B, F), device="cuda", dtype=input_dtype)
                elif input_layout == "batch_first":
                    batch_first = True
                    x = torch.randn((B, T, F), device="cuda", dtype=input_dtype)
                elif input_layout == "packed":
                    batch_first = False
                    x = torch.nn.utils.rnn.pack_padded_sequence(
                        torch.randn((T, B, F), device="cuda", dtype=input_dtype),
                        lengths=(3, 2, 1, 3),
                        enforce_sorted=False,
                    )

                rnn = (
                    getattr(torch.nn, cls)(
                        F,
                        H,
                        num_layers=num_layers,
                        bidirectional=bidirectional,
                        bias=bias,
                        batch_first=batch_first,
                    )
                    .cuda()
                    .to(dtype=weight_dtype)
                )

                if try_nonpreflattened_weights:
                    for p in rnn.parameters():
                        with torch.no_grad():
                            p.set_(p.clone())

                h = torch.randn(
                    (num_layers * (2 if bidirectional else 1), B, H),
                    device="cuda",
                    dtype=hidden_dtype,
                )
                if cls == "LSTM":
                    c = torch.randn(
                        (num_layers * (2 if bidirectional else 1), B, H),
                        device="cuda",
                        dtype=hidden_dtype,
                    )
                    h = (h, c)

                with torch.autocast("cuda"):
                    out, h_out = rnn(x, h)
                out = out.data if input_layout == "packed" else out
                self.assertEqual(out.dtype, torch.float16)
                self.assertEqual(
                    out.grad_fn.name(),
                    "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0",
                )
                out.sum().backward()
                grads = [p.grad.clone() for p in rnn.parameters()]

                rnn.zero_grad()

                if cls == "LSTM":
                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
                        x.half(), (h[0].half(), h[1].half())
                    )
                else:
                    out_control, h_out_control = rnn.to(dtype=torch.float16)(
                        x.half(), h.half()
                    )
                out_control = (
                    out_control.data if input_layout == "packed" else out_control
                )
                out_control.sum().backward()
                grads_control = [p.grad.clone() for p in rnn.parameters()]

                # Compares with default tolerances, even for FP16 execution.
                self.assertEqual(out, out_control)

                if cls == "LSTM":
                    self.assertTrue(
                        h_out[0].dtype is torch.float16
                        and h_out[1].dtype is torch.float16
                    )
                    self.assertEqual(h_out[0], h_out_control[0])
                    self.assertEqual(h_out[1], h_out_control[1])
                else:
                    self.assertEqual(h_out.dtype, torch.float16)
                    self.assertEqual(h_out, h_out_control)
                for grad, grad_control in zip(grads, grads_control):
                    self.assertEqual(grad.half(), grad_control)

    def test_autocast_cache_leak(self):
        # Checks that autocast doesn't re-cache casted parameters across
        # forward passes made from the same autocast context.
        linear = torch.nn.Linear(10, 10).to("cuda")
        data = torch.randn(1, 10, device="cuda")

        with torch.autocast("cuda"):
            with torch.no_grad():
                out = linear(data)
                first_iter_mem = torch.cuda.memory_allocated()
                for _ in range(3):
                    out = linear(data)
                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())

    def test_autocast_checkpointing(self):
        model = torch.nn.Sequential(
            torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
        ).cuda()
        input = torch.rand(
            (8, 8), device="cuda", dtype=torch.float16, requires_grad=True
        )
        for reentrant in (True, False):
            with torch.autocast("cuda"):
                output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant)
                self.assertTrue(output.requires_grad)
                self.assertTrue(output.dtype is torch.float16)
                output.sum().backward()

    def test_cuda_autocast_deprecated_warning(self):
        with self.assertWarnsRegex(
            FutureWarning,
            r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.",
        ):
            with torch.cuda.amp.autocast():
                _ = torch.ones(10)

    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    def test_max_large_axis(self):
        x = torch.zeros(2**32, device="cuda", dtype=torch.int8)
        x[-1] = 1
        val, idx = x.max(0)
        self.assertEqual(val, 1)
        self.assertEqual(idx, x.shape[0] - 1)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_to_numpy(self):
        self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy())

    def test_graph_is_current_stream_capturing(self):
        self.assertFalse(torch.cuda.is_current_stream_capturing())

        if TEST_CUDA and (not TEST_WITH_ROCM):
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
                g = torch.cuda.CUDAGraph()
                self.assertFalse(torch.cuda.is_current_stream_capturing())
                g.capture_begin()
                self.assertTrue(torch.cuda.is_current_stream_capturing())
                g.capture_end()

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_simple(self):
        s = torch.cuda.Stream()

        with torch.cuda.stream(s):
            a = torch.full((1000,), 1, device="cuda")
            g = torch.cuda.CUDAGraph()
            torch.cuda.empty_cache()
            g.capture_begin()
            b = a
            for _ in range(10):
                b = b + 1
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        g.replay()

        self.assertTrue(b.sum().item() == 11000.0)
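
    # CUDA graph capture pattern used below: allocate on a side stream, wrap
    # the work in capture_begin()/capture_end(), make the default stream wait
    # on the side stream, then g.replay() re-executes the captured kernels.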
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graphsafe_set_get_rng_state(self):
        def create_states(generator):
            """Initializes generator states and registers them with a CUDA graph if provided."""
            # Ensure the CUDA generator is initialized.
            torch.rand(1, device="cuda")
            generator.manual_seed(0)

            # Save the current state of the generator.
            old_state = generator.graphsafe_get_state()
            # Create and save a cloned state of the generator.
            new_state = generator.clone_state()
            return generator, old_state, new_state

        def register_states_to_graph(generator_state, graph):
            generator, old_state, new_state = generator_state
            graph.register_generator_state(old_state)
            graph.register_generator_state(new_state)

        # Performs specific RNG actions using the generator's states.
        def perform_random_generation_steps(generator_state):
            generator, old_state, new_state = generator_state
            random_values = []

            # Generate random numbers with the cloned state.
            generator.graphsafe_set_state(new_state)
            random_values.append(torch.rand(5, device="cuda", generator=generator))

            # Generate random numbers twice with the original state.
            generator.graphsafe_set_state(old_state)
            random_values.extend(
                [torch.rand(5, device="cuda", generator=generator) for _ in range(2)]
            )

            return random_values

        # Retrieves the final offsets of the original and cloned generator states.
        def get_final_offsets_of_states(generator_state):
            generator, old_state, new_state = generator_state
            old_state_offset = old_state.get_offset()
            new_state_offset = new_state.get_offset()
            return old_state_offset, new_state_offset

        # Set up a fresh CUDA generator as the eager reference.
        generator = torch.Generator(device="cuda")
        generator_state = create_states(generator)

        # Set up the default CUDA generator with a CUDA graph.
        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()
        default_generator = torch.cuda.default_generators[0]
        default_generator_state = create_states(default_generator)
        register_states_to_graph(default_generator_state, g)

        # Capture the random number generation inside a CUDA graph.
        with torch.cuda.stream(s):
            g.capture_begin()
            graphed_random_values = perform_random_generation_steps(
                default_generator_state
            )
            g.capture_end()
        torch.cuda.current_stream().wait_stream(s)

        # Replay the graph, then run the same steps eagerly and compare.
        g.replay()
        random_values = perform_random_generation_steps(generator_state)

        offset = get_final_offsets_of_states(generator_state)
        graph_offset = get_final_offsets_of_states(default_generator_state)

        # Compare the final offsets of both generators' states for consistency.
        self.assertTrue(offset == graph_offset)
        # Compare the values generated outside and inside the graph.
        self.assertEqual(random_values, graphed_random_values)
2227
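    # Each registered generator state pins a seed tensor and an offset tensor,
    # hence the expected delta of two 512-byte blocks per generator state
    # below, independent of how many graphs the states are registered with.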
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_memory_stats_of_multiple_generators_and_graphs(self):
        def clear_cuda_cache():
            torch.cuda.empty_cache()

        def simple_graph_task(graph):
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
                graph.capture_begin()
                torch.rand(1, device="cuda")
                graph.capture_end()
            torch.cuda.current_stream().wait_stream(s)

        def get_memory_stats():
            stats = torch.cuda.memory_stats()
            num_blocks = stats["active.all.current"]
            total_size = stats["active_bytes.all.current"]
            return num_blocks, total_size

        def test(num_graphs, num_generators):
            baseline = get_memory_stats()
            baseline_num_blocks, baseline_total_size = baseline

            graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)]
            default_generator = torch.cuda.default_generators[0]
            generators = [default_generator.graphsafe_get_state()]

            for _ in range(1, num_generators):
                generators.append(default_generator.clone_state())

            for graph in graphs:
                for generator_state in generators:
                    graph.register_generator_state(generator_state)
                simple_graph_task(graph)

            num_blocks, total_size = get_memory_stats()
            expected_blocks_diff = 2 * num_generators
            expected_size_diff = 2 * 512 * num_generators

            self.assertTrue(
                (num_blocks - baseline_num_blocks) == expected_blocks_diff,
                "Unexpected number of active blocks.",
            )
            self.assertTrue(
                (total_size - baseline_total_size) == expected_size_diff,
                "Unexpected total memory size.",
            )

            graph = graphs.pop()

            self.assertTrue(
                get_memory_stats() == baseline,
                "Memory stats do not match baseline after cleanup.",
            )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_reset_recapture(self):
        s = torch.cuda.Stream()

        with torch.cuda.stream(s):
            a = torch.full((1000,), 1, device="cuda")
            g = torch.cuda.CUDAGraph()
            torch.cuda.empty_cache()

        torch.cuda.current_stream().wait_stream(s)

        self.assertTrue(b.sum().item() == 11000.0)

        with torch.cuda.stream(s):

        torch.cuda.current_stream().wait_stream(s)

        self.assertTrue(b.sum().item() == 22000.0)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_debugdump(self):
        torch.cuda.empty_cache()
        x = torch.randn(10240000, device="cuda")
        y = torch.rand_like(x)
        g = torch.cuda.CUDAGraph()
        g.enable_debug_mode()
        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s0.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s0):
        with torch.cuda.stream(s1):
        torch.cuda.synchronize()
        with tempfile.TemporaryDirectory() as tempdir:
            g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot"))

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_error(self):
        g = torch.cuda.CUDAGraph()
        except RuntimeError as e:
            if "CUDA graphs must be captured on a non-default stream." in str(e):
        try:
            a = subprocess.check_output(
                [sys.executable, "-c", script],
                stderr=subprocess.STDOUT,
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
        except subprocess.CalledProcessError as e:
            if e.returncode == 1:
                "Error raised by starting capture without a stream is not the expected one",
            elif e.returncode == 2:
                "Error raised by starting capture without a stream was not caught",

    @unittest.skipIf(
        (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11,
        "CUDA >= 11.0 required for graphs",
    )
    def test_graph_warn_if_has_zero_nodes(self):
        with warnings.catch_warnings(record=True) as caught:
            g = torch.cuda.CUDAGraph()
            s = torch.cuda.Stream()
            with torch.cuda.stream(s):
        self.assertTrue(
            any("The CUDA Graph is empty" in str(w.message) for w in caught)
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @unittest.skipIf(
        IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
    )
    def test_graph_capture_oom(self):
        oom_regex = (
            "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory"
        )
        with self.assertRaisesRegex(RuntimeError, oom_regex):
            with torch.cuda.graph(torch.cuda.CUDAGraph()):
                torch.zeros(2**40, device="cuda")

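    # cuBLAS keeps a per-stream workspace; repeatedly capturing and replaying
    # graphs that run matmuls should not accumulate workspace memory. The
    # check below tolerates at most ~0.1 GB of drift over 100 iterations.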
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_repeat_graph_capture_cublas_workspace_memory(self):
        (x, y, z) = 1024, 512, 64
        a = torch.rand((x, y), device="cuda")
        b = torch.rand((y, z), device="cuda")

        free_bytes_before, total_bytes = torch.cuda.mem_get_info()
        used_gb_before = (total_bytes - free_bytes_before) / 1e9

        for i in range(100):
            torch_graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(torch_graph):
            torch_graph.replay()

        free_bytes_after, _ = torch.cuda.mem_get_info()
        used_gb_after = (total_bytes - free_bytes_after) / 1e9

        self.assertFalse(used_gb_before + 0.1 < used_gb_after)

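    # For RNG ops like dropout and rrelu, a captured graph fixes the kernels
    # but not the random sequence: every replay must consume fresh offsets,
    # matching what the same chain of ops produces in eager mode.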
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_rng_functional(self):
        ops_with_kwargs = (
            (torch.nn.functional.dropout, {"p": 0.1}),
            (torch.nn.functional.rrelu, {"training": True}),
        )

        def run(op, kwargs):
            a = torch.randn((size,), device="cuda", dtype=torch.float)

            torch.cuda.manual_seed(5)
            eager_out = op(eager_out, **kwargs)

            graph_in = a.clone()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                torch.cuda.manual_seed(5)

                g = torch.cuda.CUDAGraph()
                torch.cuda.empty_cache()
                graph_out = graph_in
                graph_out = op(graph_out, **kwargs)

            torch.cuda.current_stream().wait_stream(stream)

            out = op(graph_out, **kwargs)
            out = op(out, **kwargs)

            try:
                self.assertEqual(eager_out, graph_out)
            except Exception as e:
                raise RuntimeError("Failed on ", op) from e

            seeds = [6, 128, 9999]

            torch.cuda.manual_seed(seed)

            try:
                self.assertNotEqual(eager_out, graph_out)
            except Exception as e:
                raise RuntimeError("Failed on ", op) from e

            torch.cuda.manual_seed(seed)

            eager_out = op(eager_out, **kwargs)
            eager_out = op(eager_out, **kwargs)

            try:
                self.assertEqual(eager_out, graph_out)
            except Exception as e:
                raise RuntimeError("Failed on ", op) from e

            torch.cuda.synchronize()

        for op, kwargs in ops_with_kwargs:
            run(op, kwargs)

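    # Same check across torch-level distributions (bernoulli, normal, poisson,
    # rand, randint, randn) and in-place Tensor methods (bernoulli_,
    # exponential_, geometric_, log_normal_): graphed and eager results must
    # track each other seed for seed.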
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_rng_distributions(self):
        input = torch.rand((size,), device="cuda", dtype=torch.float)
        alloc = torch.empty((size,), device="cuda", dtype=torch.float)

        torch_with_args = (
            ("bernoulli", (input.clone(),), {}),
            ("normal", (input.clone() + 1, 1.0), {}),
            ("poisson", (input.clone(),), {}),
            ("rand", (size,), {"device": "cuda", "dtype": torch.float}),
            ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}),
            ("randn", (size,), {"device": "cuda", "dtype": torch.float}),
        )

        tensor_with_args = (
            ("bernoulli_", (input.clone(),)),
            ("exponential_", ()),
            ("geometric_", (0.3,)),
            ("log_normal_", ()),
        )

        def run(module, op, args, kwargs):
            torch.cuda.manual_seed(5)

            if module == "torch":
                dummy = getattr(torch, op)(*args, **kwargs)
                control1 = getattr(torch, op)(*args, **kwargs)
                control2 = getattr(torch, op)(*args, **kwargs)
            else:
                dummy = alloc.clone()
                control1 = alloc.clone()
                control2 = alloc.clone()
                getattr(dummy, op)(*args)
                getattr(control1, op)(*args)
                getattr(control2, op)(*args)

            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                torch.cuda.manual_seed(5)

                g = torch.cuda.CUDAGraph()
                torch.cuda.empty_cache()
                if module == "torch":
                    t1 = getattr(torch, op)(*args, **kwargs)
                    t2 = getattr(torch, op)(*args, **kwargs)
                else:
                    getattr(t1, op)(*args)
                    getattr(t2, op)(*args)

            torch.cuda.current_stream().wait_stream(stream)

            if not TEST_CUDAMALLOCASYNC:
                try:
                    self.assertNotEqual(control1, t1)
                    self.assertNotEqual(control2, t2)
                except Exception as e:
                    raise RuntimeError("Failed on " + module + "." + op) from e

            for seed in [6, 314, 271]:
                torch.cuda.manual_seed(seed)
                if module == "torch":
                    dummy = getattr(torch, op)(*args, **kwargs)
                    control1 = getattr(torch, op)(*args, **kwargs)
                    control2 = getattr(torch, op)(*args, **kwargs)
                else:
                    getattr(dummy, op)(*args)
                    getattr(control1, op)(*args)
                    getattr(control2, op)(*args)

                torch.cuda.manual_seed(seed)
                if module == "torch":
                    dummy = getattr(torch, op)(*args, **kwargs)
                else:
                    getattr(dummy, op)(*args)

                if not TEST_CUDAMALLOCASYNC:
                    try:
                        self.assertEqual(control1, t1)
                        self.assertEqual(control2, t2)
                    except Exception as e:
                        raise RuntimeError("Failed on " + module + "." + op) from e

            torch.cuda.synchronize()

        for op_with_args in torch_with_args:
            run("torch", *op_with_args)

        for meth_with_args in tensor_with_args:
            run("Tensor", *(meth_with_args + ({},)))

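    # Two back-to-back captures can share one memory pool, either by passing
    # g0.pool() or a common graph_pool_handle() to capture_begin. Sharing
    # lowers reserved memory but couples the graphs' allocations.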
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_two_successive(self):
        torch.cuda.empty_cache()

        kSmallBuffer = 2097152

        def func_with_temps(t, val):

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()

            a = torch.ones((size,), device="cuda")

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = func_with_temps(b, 1)

                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
                g1.capture_begin(*g1_args)
                b = func_with_temps(b, 1)

            torch.cuda.current_stream().wait_stream(s)

            c = func_with_temps(c, 3)
            c = func_with_temps(c, 3)
            c = func_with_temps(c, 3)

            self.assertEqual(b.sum().item(), size * 3070)
            self.assertEqual(c.sum().item(), size * 442)

            if not TEST_CUDAMALLOCASYNC:
                if share_mem != "Don't share":
                    - torch.cuda.memory_stats()["reserved_bytes.all.current"],
                else:
                    reserved_no_sharing = torch.cuda.memory_stats()[
                        "reserved_bytes.all.current"
                    ]

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    @unittest.skipIf(
        (not TEST_CUDA_GRAPH)
        and int(torch.version.cuda.split(".")[0]) == 11
        and int(torch.version.cuda.split(".")[1]) < 4,
        "Graph bindings disallow concurrent replay for CUDA < 11.4, see "
        + "https://github.com/pytorch/pytorch/pull/57556",
    )
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_concurrent_replay(self):
        torch.cuda.empty_cache()

        def func_with_temps(t, val):

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()

            s0 = torch.cuda.Stream()
            s1 = torch.cuda.Stream()

            a = torch.ones((size,), device="cuda")

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)
                b = func_with_temps(b, 1)

                g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args
                g1.capture_begin(*g1_args)
                c = func_with_temps(c, 2)

            torch.cuda.synchronize()
            with torch.cuda.stream(s0):
                torch.cuda._sleep(1000000)
            with torch.cuda.stream(s1):
            torch.cuda.current_stream().wait_stream(s0)
            torch.cuda.current_stream().wait_stream(s1)

            if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"):
                self.assertNotEqual(b.sum().item(), size * 94)
                self.assertNotEqual(c.sum().item(), size * 156)
            else:
                self.assertEqual(b.sum().item(), size * 94)
                self.assertEqual(c.sum().item(), size * 156)

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_three_successive(self):
        torch.cuda.empty_cache()

        s = torch.cuda.Stream()

        for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"):
            a = torch.ones((size,), device="cuda")

            g0 = torch.cuda.CUDAGraph()
            g1 = torch.cuda.CUDAGraph()
            g2 = torch.cuda.CUDAGraph()

            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                g0_args = (
                    (torch.cuda.graph_pool_handle(),)
                    if share_mem == "via graph_pool_handle()"
                    else ()
                )
                g0.capture_begin(*g0_args)

                args = (g0.pool(),) if share_mem == "via pool()" else g0_args
                g1.capture_begin(*args)

                g2.capture_begin(*args)

            torch.cuda.current_stream().wait_stream(s)

            self.assertEqual(e.sum().item(), size * 5)
            self.assertEqual(f.sum().item(), size * 7)

            expect_corruption = (not TEST_CUDAMALLOCASYNC) and (
                share_mem != "Don't share"
            )
            self.assertEqual(
                e.sum().item(), size * (7 + 3) if expect_corruption else size * 5
            )
            self.assertEqual(f.sum().item(), size * 7)

        del a, b, d, e, f, g0, g1, g2

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

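    # The constants below mirror the caching allocator's sizing: requests up
    # to kSmallSize are served from kSmallBuffer (2 MiB) segments, larger ones
    # from kLargeBuffer (20 MiB), and big requests round up by kRoundLarge.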
    @unittest.skipIf(
        (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC,
        "CUDA >= 11.0 or ROCM >= 5.3 required for graphs",
    )
    def test_graph_memory_stats_and_use_result_after_destroy_graph(self):
        kSmallSize = 1048576
        kSmallBuffer = 2097152
        kLargeBuffer = 20971520
        kMinLargeAlloc = 10485760
        kRoundLarge = 2097152

        cases = (
            (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"),
            (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"),
            ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"),
            (kMinLargeAlloc - 512) // elem,
            (kMinLargeAlloc + 512) // elem,
            * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge)
            kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge),
        )

        stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.")

        torch.cuda.empty_cache()

        s = torch.cuda.Stream()

        for (
            numel,
            delta_cudaMalloc_bytes,
            delta_cudaMalloc_bytes_post_del_g,
            pool_string,
        ) in cases:
            if pool_string == "small_pool":
                delta_active_blocks = 3
                delta_active_bytes = (
            else:
                delta_active_blocks = 1
                delta_active_bytes = numel * elem

            g = torch.cuda.CUDAGraph()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                a = torch.ones((numel,), device="cuda")

            precapture_stats = torch.cuda.memory_stats()

            torch.cuda.current_stream().wait_stream(s)

            postcapture_stats = torch.cuda.memory_stats()

            expecteds = (
                delta_cudaMalloc_bytes,
                delta_active_blocks,
            )

            for stat, expected in zip(stats_to_check, expecteds):
                stat = stat + pool_string + ".current"
                current = postcapture_stats[stat] - precapture_stats[stat]

                if self.expandable_segments and "segment" in stat:
                if (
                    self.expandable_segments
                    and "reserved" in stat
                    and (numel == cases[3][0] or numel == cases[4][0])
                ):
                    expected = 2 * kLargeBuffer

                self.assertEqual(
                    current,
                    expected,
                    "Pre to post capture delta of "
                    + f" = {current}, expected = {expected}, numel = {numel}",
                )

            self.assertEqual(b.sum().item(), 6 * numel)

            torch.cuda.empty_cache()

            torch.cuda.empty_cache()
            postdel_stats = torch.cuda.memory_stats()

            self.assertEqual(b.sum().item(), 6 * numel)

            expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem)
            for stat, expected in zip(stats_to_check, expecteds):
                stat = stat + pool_string + ".current"
                current = postdel_stats[stat] - precapture_stats[stat]

                if self.expandable_segments and "segment" in stat:
                if (
                    self.expandable_segments
                    and "reserved" in stat
                    and numel == cases[3][0]
                ):
                    expected = 2 * kLargeBuffer
                if (
                    self.expandable_segments
                    and "reserved" in stat
                    and numel == cases[4][0]
                ):
                    expected = kLargeBuffer

                self.assertEqual(
                    current,
                    expected,
                    "Pre capture to post graph delete delta of "
                    + f" = {current}, expected = {expected}, numel = {numel}",
                )

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_record_stream(self):
        torch.cuda.empty_cache()

        potential_problem = torch.zeros((3,), device="cuda")
        a = torch.zeros((3,), device="cuda")
        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()
        g = torch.cuda.CUDAGraph()

        torch.cuda.synchronize()
        with torch.cuda.stream(s0):
            potential_problem.record_stream(s0)
            torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES)
            potential_problem.fill_(1.0)
        del potential_problem

        with torch.cuda.stream(s1):
        with torch.cuda.stream(s2):
        torch.cuda.synchronize()

        c = torch.zeros((3,), device="cuda")

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @skipCUDAMemoryLeakCheckIf(True)
    def test_graph_cudnn_dropout(self):
        torch.cuda.empty_cache()

        model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda()
        x = torch.ones(100, 192, 512, device="cuda")

        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):

        torch.cuda.current_stream().wait_stream(s)

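    # make_graphed_callables wraps modules and functions so their forward and
    # backward run as graph replays; the parametrization sweeps amp on/off,
    # the autocast cache, and the allow_unused_input escape hatch.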
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize(
        "with_amp,cache_enabled,allow_unused_input",
        [
            subtest((False, False, True), decorators=[skipIfRocm]),
            subtest((True, False, True), decorators=[skipIfRocm]),
            subtest((True, True, True), decorators=[unittest.expectedFailure]),
            subtest((False, False, False), decorators=[unittest.expectedFailure]),
        ],
        name_fn=lambda x, y, z: "{}{}{}".format(
            {True: "with_amp", False: "without_amp"}[x],
            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
        ),
    )
    def test_graph_make_graphed_callables(
        self, with_amp, cache_enabled, allow_unused_input
    ):
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)

        N, D_in, H, D_out = 640, 4096, 2048, 1024

        class MLP1(torch.nn.Module):
            def __init__(self, D_in: int, H: int, D_out: int):
                super().__init__()
                self.net_1 = torch.nn.Sequential(
                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
                )
                self.net_2 = torch.nn.Sequential(
                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
                )

            def forward(self, input_dict: dict):
                return self.net_2(self.net_1(x))

        class MLP2(torch.nn.Module):
            def __init__(self, D_in: int, H: int, D_out: int):
                super().__init__()
                self.net_1 = torch.nn.Sequential(
                    torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1)
                )
                self.net_2 = torch.nn.Sequential(
                    torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2)
                )

            def forward(self, x):
                return self.net_2(self.net_1(x))

        class ParameterlessModule(torch.nn.Module):
            def forward(self, x):
                idx = (
                    torch.arange(x.size(0), device=x.device)
                    .repeat(1, x.size(1))
                )
                return {"output": torch.gather(x, 0, idx)}

        model_section1 = MLP1(D_in, H, H).cuda()
        model_section2 = MLP2(H, H, D_out).cuda()
        model_section3 = ParameterlessModule().cuda()

        models.append(
            torch.nn.Sequential(model_section1, model_section2, model_section3)
        )

        model_graphed = models[0]
        model_control = models[1]

        model_graphed.load_state_dict(model_control.state_dict())

        opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1)
        opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1)

        x = torch.randn(N, D_in, device="cuda")
        h = torch.randn(N, H, device="cuda", requires_grad=True)
        h2 = torch.randn(N, D_out, device="cuda", requires_grad=True)
        unused_input = torch.randn(N, H, device="cuda", requires_grad=True)
        y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True)
        y = torch.randn(N, D_out, device="cuda")

        loss_fn_control = torch.nn.functional.mse_loss
        relu_control = torch.nn.functional.relu

        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
            ) = torch.cuda.make_graphed_callables(
                ({"x": x, "unused_input": unused_input},),
                allow_unused_input=allow_unused_input,
            )

        real_inputs = [torch.rand_like(x) for _ in range(10)]
        real_targets = [torch.rand_like(y) for _ in range(10)]

        for m, opt, relu, loss_fn in zip(
            (model_graphed, model_control),
            (opt_graphed, opt_control),
            (relu_graphed, relu_control),
            (loss_fn_graphed, loss_fn_control),
        ):
            torch.manual_seed(5)
            torch.cuda.manual_seed(5)
            for data, target in zip(real_inputs, real_targets):
                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
                    y_pred = m({"x": data, "unused_input": unused_input})["output"]
                    y_pred = relu(y_pred)
                    loss = loss_fn(y_pred, target)

        for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
            self.assertEqual(p, pc)

        model_graphed.eval()
        model_control.eval()
        self.assertEqual(
            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize(
        "with_amp,cache_enabled,allow_unused_input",
        [
            subtest((False, False, True), decorators=[skipIfRocm]),
            subtest((True, False, True), decorators=[skipIfRocm]),
            subtest((True, True, True), decorators=[unittest.expectedFailure]),
            subtest((False, False, False), decorators=[skipIfRocm]),
        ],
        name_fn=lambda x, y, z: "{}{}{}".format(
            {True: "with_amp", False: "without_amp"}[x],
            {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "",
            {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z],
        ),
    )
    def test_graph_make_graphed_callables_parameterless_nograd_module(
        self, with_amp, cache_enabled, allow_unused_input
    ):
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)

        N, D_in, H, D_out = 640, 4096, 2048, 1024

        class ParameterlessModule(torch.nn.Module):
            def forward(self, input_dict: dict):
                idx = (
                    torch.arange(x.size(0), device=x.device)
                    .repeat(1, x.size(1))
                )
                return {"output": torch.gather(x, 0, idx)}

        model_section1 = ParameterlessModule().cuda()
        models.append(torch.nn.Sequential(model_section1))

        model_graphed = models[0]
        model_control = models[1]

        model_graphed.load_state_dict(model_control.state_dict())

        x = torch.randn(N, D_in, device="cuda", requires_grad=False)
        unused_input = torch.randn(N, H, device="cuda", requires_grad=False)
        y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False)
        y = torch.randn(N, D_in, device="cuda")

        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
            model_graphed[0] = torch.cuda.make_graphed_callables(
                ({"x": x, "unused_input": unused_input},),
                allow_unused_input=allow_unused_input,
            )

        real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)]
        real_targets = [torch.rand_like(y) for _ in range(10)]

        for m in (model_graphed, model_control):
            torch.manual_seed(5)
            torch.cuda.manual_seed(5)
            for data, target in zip(real_inputs, real_targets):
                with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
                    out = m({"x": data, "unused_input": unused_input})["output"]

        model_graphed.eval()
        model_control.eval()
        self.assertEqual(
            model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_make_graphed_callables_same_pool(self):
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)

        models = []
        for _ in range(num_models):
            models.append(
                torch.nn.Sequential(
                    torch.nn.Linear(32, 128),
                    torch.nn.Linear(128, 128),
                ).cuda()
            )

        mempool = torch.cuda.graph_pool_handle()
        graphed_models = []
        for model in models:
            x = torch.randn([64, 32], device="cuda")
            graphed_model = deepcopy(model)
            graphed_model = torch.cuda.make_graphed_callables(
                graphed_model, (x,), pool=mempool
            )
            graphed_models.append(graphed_model)

        for model, graphed_model in zip(models, graphed_models):
            x = torch.randn([64, 32], device="cuda")

            yg = graphed_model(x)

            self.assertEqual(y, yg)
            self.assertEqual(l, lg)
            for p, pg in zip(model.parameters(), graphed_model.parameters()):
                self.assertEqual(p, pg)
                self.assertEqual(p.grad, pg.grad)
                self.assertNotEqual(p.data_ptr(), pg.data_ptr())
                self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr())

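    # Shared harness for capturable optimizers: step a control optimizer
    # eagerly (capturable=False) and a capturable one, optionally replaying
    # its step from a captured graph, then require identical parameters.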
    def _test_graphed_optimizer(
        self, steps_warmup, steps_train, optimizer_ctor, kwargs
    ):
        for actually_do_graphs in (True, False):
            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [
                torch.randn((), device="cuda")
            ]
            params_control = [p.clone().requires_grad_() for p in params]
            params_graphed = [p.clone().requires_grad_() for p in params]

            grads = [
                [torch.randn_like(p) for p in params]
                for _ in range(steps_warmup + steps_train)
            ]

            opt = optimizer_ctor(params_control, capturable=False, **kwargs)

            for i in range(steps_warmup + steps_train):
                for j, p in enumerate(params_control):
                    p.grad = grads[i][j]
                opt.step()

            opt = optimizer_ctor(params_graphed, capturable=True, **kwargs)

            for i in range(steps_warmup):
                for j, p in enumerate(params_graphed):
                    p.grad = grads[i][j]
                opt.step()

            if actually_do_graphs:
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    opt.step()

            for i in range(steps_train):
                if actually_do_graphs:
                    for j, p in enumerate(params_graphed):
                        p.grad.copy_(grads[i + steps_warmup][j])
                    g.replay()
                else:
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i + steps_warmup][j]
                    opt.step()

            for p_control, p_graphed in zip(params_control, params_graphed):
                self.assertEqual(p_control, p_graphed)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_optims_with_explicitly_capturable_param_groups(self):
        n_warmup, n_replay = 3, 2
        for optimizer, second_param_group_capturable in product(
            torch.optim.Adadelta,
            torch.optim.RMSprop,
        ):
            param1, param2 = (
                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
            )
            ref_p1, ref_p2 = (
                torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2)
            )
            grads1, grads2 = (
                [torch.randn_like(param1) for _ in range(n_warmup + n_replay)]
                for _ in range(2)
            )
            ref_grads1, ref_grads2 = (
                [t.clone() for t in tensors] for tensors in (grads1, grads2)
            )
            params = [
                {"params": [param1], "capturable": True},
                {"params": [param2], "capturable": second_param_group_capturable},
            ]
            opt = optimizer(params)
                {"params": [ref_p1], "capturable": False},
                {"params": [ref_p2], "capturable": False},

            for i in range(n_warmup + n_replay):
                ref_p1.grad = ref_grads1[i]
                ref_p2.grad = ref_grads2[i]

            for i in range(n_warmup):
                param1.grad = grads1[i]
                param2.grad = grads2[i]

            g = torch.cuda.CUDAGraph()
            if not second_param_group_capturable:
                with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"):
                    with torch.cuda.graph(g):
            else:
                with torch.cuda.graph(g):

            for i in range(n_replay):
                param1.grad.copy_(grads1[n_warmup + i])
                param2.grad.copy_(grads2[n_warmup + i])

            self.assertEqual(ref_p1, param1)
            self.assertEqual(ref_p2, param2)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_cuda_graph_error_options(self):
        x = torch.zeros([2000], device="cuda")

        stream = torch.cuda.Stream()
        try:
            with torch.cuda.stream(stream):
                mem = torch.cuda.caching_allocator_alloc(1024)
        except BaseException:
            pass
        try:
            torch.cuda.caching_allocator_delete(mem)
        except BaseException:
            pass

        def throws_on_cuda_event(capture_error_mode):
            graph = torch.cuda.CUDAGraph()
            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
            stream.synchronize()
            torch.cuda.current_stream().wait_stream(stream)
            torch.cuda.synchronize()
            with torch.cuda.graph(
                graph, stream=stream, capture_error_mode=capture_error_mode
            ):
                thread = threading.Thread(target=raw_malloc)

            torch.cuda.caching_allocator_delete(mem)

        self.assertFalse(throws_on_cuda_event("thread_local"))
        self.assertFalse(throws_on_cuda_event("relaxed"))

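    # Allocations made while capturing should land in the graph's private
    # pool no matter which stream issues them; the test below expects the
    # work on both capture streams to show up in the same new segment pool.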
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_cuda_graph_allocator_propagates_stream(self):
        segments = torch.cuda.memory_snapshot()
        existing_pools = {s["segment_pool_id"] for s in segments}
        x = torch.randn(10240000, device="cuda")
        y = torch.rand_like(x)
        g = torch.cuda.CUDAGraph()
        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        s0.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s0):
        with torch.cuda.stream(s1):
        with torch.cuda.stream(s0):
        segments = torch.cuda.memory_snapshot()
        x = [
            s["segment_pool_id"]
            for s in segments
            if s["segment_pool_id"] not in existing_pools
        ]
        self.assertEqual(len(x), 2)
        self.assertEqual(x[0], x[1])

    def test_batch_norm_gather_stats(self):
        input = torch.randn(1, 3, 3, 3, device="cuda")
        mean, invstd = torch.batch_norm_gather_stats(
            mean=torch.ones(2, 3, device="cuda"),
            invstd=torch.ones(2, 3, device="cuda"),
        )
        self.assertEqual(mean, torch.ones(3, device="cuda"))
        self.assertEqual(invstd, torch.ones(3, device="cuda"))

    def test_matmul_memory_use(self):
        def get_max_used():
            torch.cuda.synchronize()
            val = torch.cuda.max_memory_allocated()
            torch.cuda.reset_peak_memory_stats()
            return val

        a = torch.rand(1, 32, 32, device="cuda")
        b = torch.rand(24, 32, 1, device="cuda")

        matmul_mem = get_max_used()

        a = a.expand(24, 32, 32)
        matmul_expand_mem = get_max_used()

        bmm_mem = get_max_used()

        self.assertEqual(matmul_expand_mem, matmul_mem)
        self.assertEqual(bmm_mem, matmul_mem)

    @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test")
    def test_rocm_backward_pass_guard(self):
        class MyFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, tensor, constant):
                self.assertFalse(torch._C._rocm_is_backward_pass())
                ctx.constant = constant
                return tensor * constant

            @staticmethod
            def backward(ctx, grad_output):
                self.assertTrue(torch._C._rocm_is_backward_pass())
                return grad_output * ctx.constant, None

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.nn.Parameter(torch.randn(()))

            def forward(self, x):
                return MyFunction.apply(x, self.a)

        model = MyModule()
        criterion = torch.nn.MSELoss(reduction="sum")
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)

        x = torch.randn(5, 5)
        result = model(x)
        loss = criterion(result, x)
        optimizer.zero_grad()
        loss.backward()

    def test_matmul_device_mismatch(self):
        cpu = torch.rand((10, 10))
        cuda = cpu.cuda()
        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            cpu @ cuda
        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            cuda @ cpu

        for s, m1, m2 in product((cpu, cuda), repeat=3):
            if s.device == m1.device == m2.device:
                torch.addmm(s, m1, m2)
            else:
                with self.assertRaisesRegex(
                    RuntimeError, "Expected all tensors to be on the same device"
                ):
                    torch.addmm(s, m1, m2)

    @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient")
    def test_lazy_init(self):
        """Validate that no CUDA calls are made during `import torch` call"""

        def check_output(script: str) -> str:
            return (
                subprocess.check_output([sys.executable, "-c", script])
                .decode("ascii")
                .strip()
            )

        VISIBLE_DEVICES = (
            "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES"
        )
        test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())"
        rc = check_output(test_script)
        self.assertEqual(rc, "0")
        if not TEST_WITH_ROCM:
            libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll"
            cuda_driver_api_call = (
                f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))"
            )
            rc = check_output(
                f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})"
            )
            self.assertEqual(rc, "3")

    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
    def test_hip_device_count(self):
        """Validate device_count works with both CUDA/HIP visible devices"""
        test_script = """\
print(f"{torch.cuda.device_count()}")
"""
        custom_envs = [
            {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None},
            {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"},
            {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"},
        ]

        for env_config in custom_envs:
            env = os.environ.copy()
            for key, value in env_config.items():

            r = (
                subprocess.check_output([sys.executable, "-c", test_script], env=env)
                .decode("ascii")
                .strip()
            )
            self.assertEqual("1", r)

    @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices")
    def test_device_count_not_cached_pre_init(self):
        visible_devices = (
            "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES"
        )
        test_script = f"""\
r1 = torch.cuda.device_count()
os.environ['{visible_devices}'] = '0'
r2 = torch.cuda.device_count()
torch.empty(10, device='cuda')
print(f"{{r1}}, {{r2}}")
"""

        r = (
            subprocess.check_output([sys.executable, "-c", test_script])
            .decode("ascii")
            .strip()
        )

        x = torch.cuda.device_count()
        self.assertEqual(f"{x}, 1", r)

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_fails_in_ci(self):
        if IS_WINDOWS or TEST_WITH_ROCM:
            error_msg = "is not supported on this platform"
        else:
            error_msg = "cuFileHandleRegister failed"
        with TemporaryFileName() as f:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)

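# The tests below exercise the allocator's introspection hooks
# (_record_memory_history, _snapshot, and the _memory_viz helpers); most are
# skipped under cudaMallocAsync, which has no setContextRecorder support.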
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaMallocAsync(TestCase):
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_memory_snapshot(self):
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._record_memory_history("state", stacks="python")
        torch.rand(2 * 311, 411, device="cuda")
        unused = torch.rand(310, 410, device="cuda")
        x = torch.rand(311, 411, device="cuda")

        tensors = [torch.rand(128, device="cuda") for _ in range(1000)]
        del tensors[randint(0, len(tensors) - 1)]
        torch.rand(128 * 5, device="cuda")

        ss = torch.cuda.memory._snapshot()
        for seg in ss["segments"]:
            self.assertTrue("frames" in seg)
            for b in seg["blocks"]:
                if b["requested_size"] == 311 * 411 * 4:
                    self.assertTrue("test_cuda" in b["frames"][0]["filename"])
                    found_it = True
                    self.assertEqual(x.untyped_storage().data_ptr(), b["address"])
        self.assertTrue(found_it)

        with tempfile.NamedTemporaryFile() as f:
            torch.cuda.memory._save_segment_usage(f.name)
            with open(f.name) as f2:
                self.assertTrue("test_cuda.py" in f2.read())

        torch.cuda.empty_cache()
        ss = torch.cuda.memory._snapshot()
        self.assertTrue(
            ss["device_traces"][0][-1]["action"]
            in ("segment_free", "segment_unmap")
        )

        torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
    def test_direct_traceback(self):
        from torch._C._profiler import gather_traceback, symbolize_tracebacks

        c = gather_traceback(True, True, True)
        (r,) = symbolize_tracebacks([c])
        self.assertTrue("test_cuda.py" in r)
        self.assertTrue("unwind" in r)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_snapshot_with_cpp(self):
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._record_memory_history("state", stacks="all")
        x = torch.rand(311, 411, device="cuda")

        ss = torch.cuda.memory._snapshot()["segments"]
        for seg in ss:
            for b in seg["blocks"]:
                if b["requested_size"] == 311 * 411 * 4:
                    self.assertTrue("::rand" in str(b["frames"]))
                    found_it = True
        self.assertTrue(found_it)

        torch.cuda.memory._record_memory_history(None)

    def test_memory_profiler_viz(self):
        with torch.profiler.profile(
            with_stack=True, profile_memory=True, record_shapes=True
        ) as prof:
            x = torch.rand(128, 128, device="cuda")

        plot = profile_plot(prof)
        plot = json.dumps(_profile_to_snapshot(prof))
        self.assertTrue("test_cuda.py" in plot)
        self.assertTrue("test_memory_profiler_viz" in plot)
        self.assertTrue("category" in plot)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_cycles(self):
        fired = False

        def observer(html):
            nonlocal fired
            fired = True
            self.assertTrue("torch.Tensor" in html)
            self.assertTrue("test_cuda" in html)
            self.assertTrue("cell_contents" in html)

        disarm = observe_tensor_cycles(observer)

        x = torch.empty(3, 4, device="cuda")

        self.assertTrue(fired)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots(self):
        for context, stacks in (
            ("all", "all" if IS_LINUX else "python"),
        ):
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history(
                "all", context=context, stacks=stacks
            )

            x = torch.rand(128, 128, device="cuda")

            cpp = stacks == "all"
            record_context = context is not None
            ss = torch.cuda.memory._snapshot()

            tplot = trace_plot(ss)
            splot = segment_plot(ss)
            text = json.dumps(ss)

            self.assertTrue(record_context == ("test_memory_plots" in text))
            self.assertTrue(cpp == ("::rand" in text))
            self.assertTrue(str(128 * 128 * 4) in text)

            torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_free_stack(self):
        for context in ["alloc", "all", "state"]:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history(context=context)

            x = torch.rand(3, 4, device="cuda")

            ss = json.dumps(torch.cuda.memory._snapshot())
            self.assertTrue(("thefree" in ss) == (context == "all"))
            self.assertTrue(("thealloc" in ss) == (context != "state"))

            torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_history_context(self):
        torch.cuda.memory.empty_cache()

        def should_capture1():
            x = torch.rand(4, 4, device="cuda")

        def should_not_capture():
            x = torch.rand(3, 4, device="cuda")

        def should_capture2():
            x = torch.rand(4, 4, device="cuda")

        torch.cuda.memory._record_memory_history(context="all", stacks="python")
        torch.cuda.memory._record_memory_history(context=None)
        should_not_capture()
        torch.cuda.memory._record_memory_history(context="all", stacks="python")

        ss = json.dumps(torch.cuda.memory._snapshot())
        self.assertTrue("should_capture1" in ss)
        self.assertTrue("should_not_capture" not in ss)
        self.assertTrue("should_capture2" in ss)

        torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only")
    def test_memory_plots_free_segment_stack(self):
        for context in ["alloc", "all", "state"]:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history(context=context)
            x = torch.rand(3, 4, device="cuda")
            torch.cuda.memory.empty_cache()

            ss = json.dumps(torch.cuda.memory._snapshot())
            self.assertTrue(("empty_cache" in ss) == (context == "all"))

            torch.cuda.memory._record_memory_history(None)

    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_memory_snapshot_script(self):
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._record_memory_history("state", stacks="python")

        def foo():
            return torch.rand(311, 411, device="cuda")

        ss = torch.cuda.memory._snapshot()["segments"]
        for seg in ss:
            for b in seg["blocks"]:
                if b["requested_size"] == 311 * 411 * 4:
                    self.assertTrue(b["frames"][0]["name"] == "foo")
                    found_it = True
        self.assertTrue(found_it)

        torch.cuda.memory._record_memory_history(None)

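    # Both expandable-segments tests cap the process at 120 MiB through
    # set_per_process_memory_fraction so the allocation pattern (and the
    # eventual OOM) is deterministic regardless of the physical GPU size.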
    def test_max_split_expandable(self):
        torch.cuda.memory.empty_cache()
        mb = 1024 * 1024
        _, all_memory = torch.cuda.memory.mem_get_info()
        total_allowed = 120 * mb
        fraction_allowed = total_allowed / all_memory
        assert int(fraction_allowed * all_memory) == total_allowed
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

        def alloc(n):
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")

        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,max_split_size_mb:40"
        )
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:True,max_split_size_mb:40"
        )
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,max_split_size_mb:40"
        )
        with self.assertRaises(torch.OutOfMemoryError):

    def test_garbage_collect_expandable(self):
        torch.cuda.memory.empty_cache()
        mb = 1024 * 1024
        _, all_memory = torch.cuda.memory.mem_get_info()
        total_allowed = 120 * mb
        fraction_allowed = total_allowed / all_memory
        assert int(fraction_allowed * all_memory) == total_allowed
        torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed)

        def alloc(n):
            return torch.ones(n * mb, dtype=torch.int8, device="cuda")

        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:False,garbage_collection_threshold:0.5"
        )
        torch.cuda.memory._set_allocator_settings(
            "expandable_segments:True,garbage_collection_threshold:0.5"
        )

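    # power2_div models roundup_power2_divisions: sizes round up to the next
    # multiple of (next_pow2 / 2 / div_factor); e.g. an 84 MiB request with 4
    # divisions rounds to 96 MiB (16 MiB steps between 64 MiB and 128 MiB).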
    def test_allocator_settings(self):
        def power2_div(size, div_factor):
            step = pow2 / 2 / div_factor

        torch.cuda.memory.empty_cache()
        key_allocated = (
            "active_bytes.all.allocated"
            if not TEST_CUDAMALLOCASYNC
            else "allocated_bytes.all.current"
        )
        key_requested = "requested_bytes.all.allocated"

        nelems = 21 * 1024 * 1024
        nbytes = 4 * nelems

        nelems_big = 100 * 1024 * 1024
        nbytes_big = 4 * nelems_big

        start_mem = torch.cuda.memory_stats()[key_allocated]
        torch.cuda.memory._set_allocator_settings("")
        x = torch.rand(nelems, device="cuda")

        reg_mem = torch.cuda.memory_stats()[key_allocated]
        start_requested = torch.cuda.memory_stats()[key_requested]
        torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4")
        y = torch.rand(nelems, device="cuda")

        pow2_div4_mem = torch.cuda.memory_stats()[key_allocated]
        current_requested = torch.cuda.memory_stats()[key_requested]

        self.assertTrue(reg_mem - start_mem == nbytes)
        if not TEST_CUDAMALLOCASYNC:
            self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4))
            self.assertTrue(current_requested - start_requested == nbytes)

        torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5")
        torch.cuda.memory._set_allocator_settings(
            "garbage_collection_threshold:0.5,max_split_size_mb:40"
        )

        torch.cuda.memory.empty_cache()
        start_mem = torch.cuda.memory_stats()[key_allocated]
        z = torch.rand(nelems, device="cuda")
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        self.assertTrue(reg_mem - start_mem == nbytes)

        torch.cuda.memory.empty_cache()
        torch.cuda.memory._set_allocator_settings(
            "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]"
        )
        start_mem = torch.cuda.memory_stats()[key_allocated]
        w = torch.rand(nelems, device="cuda")

        pow2_div8_mem = torch.cuda.memory_stats()[key_allocated]
        if not TEST_CUDAMALLOCASYNC:
            self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8))

        torch.cuda.memory.empty_cache()
        start_mem = torch.cuda.memory_stats()[key_allocated]
        v = torch.rand(nelems_big, device="cuda")

        pow2_div2_mem = torch.cuda.memory_stats()[key_allocated]
        if not TEST_CUDAMALLOCASYNC:
            self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2))

        torch.cuda.memory.empty_cache()
        torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True")
        start_mem = torch.cuda.memory_stats()[key_allocated]
        w = torch.rand(nelems, device="cuda")
        reg_mem = torch.cuda.memory_stats()[key_allocated]
        self.assertTrue(reg_mem - start_mem == nbytes)

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("foo:1,bar:2")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "garbage_collection_threshold:1.2"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("max_split_size_mb:2")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none")

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_use_cuda_host_register:none"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_num_register_threads:none"
            )

        with self.assertRaises(RuntimeError):
            torch.cuda.memory._set_allocator_settings(
                "pinned_num_register_threads:1024"
            )

@parametrize("max_split_size_mb_setting", [False, True])
4288
def test_raises_oom(self, max_split_size_mb_setting):
4289
if max_split_size_mb_setting:
4293
torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024")
4294
torch.cuda.memory.empty_cache()
4295
with self.assertRaises(torch.cuda.OutOfMemoryError):
4296
torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")
4299
    @unittest.skipIf(
        not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux"
    )
    @unittest.skipIf(
        TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
    )
    def test_cpp_memory_snapshot_pickle(self):
        from torch.utils.cpp_extension import load_inline

        source = """
        #include <torch/csrc/cuda/memory_snapshot.h>
        py::object do_snapshot() {
            std::string data = torch::cuda::_memory_snapshot_pickled();
            return py::bytes(data);
        }
        void record(bool e, bool ctx) {
            torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx);
        }
        """
        m = load_inline(
            name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"]
        )
        for ctx in (False, True):
            def the_script_fn():
                return torch.rand(311, 411, device="cuda")

            return pickle.loads(m.do_snapshot())

            for s in mem["segments"]:
                for b in s["blocks"]:
                    if b["state"] == "active_allocated":
                        if b["requested_size"] == 311 * 411 * 4:
                            frame_text = str(b["frames"])
                            self.assertTrue("::rand" in frame_text)
                            self.assertTrue("the_script_fn" in frame_text)
                            self.assertTrue("case.py" in frame_text)
                            found = True

            last_action = mem["device_traces"][0][-1]
            self.assertTrue(last_action["action"] == "alloc")
            self.assertTrue(last_action["size"] == 311 * 411 * 4)
            self.assertTrue(found)

        m.record(False, False)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
    def test_notifies_oom(self):
        def cb(device, alloc, device_alloc, device_free):

        torch._C._cuda_attach_out_of_memory_observer(cb)
        with self.assertRaises(torch.cuda.OutOfMemoryError):
            torch.empty(1024 * 1024 * 1024 * 1024, device="cuda")

    def test_allocator_fuzz(self):
        state = random.getstate()

        def alloc():
            b = random.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4)
            mem.append((c, torch.full((b,), c, dtype=torch.int32, device="cuda")))

        def free():
            idx = random.randrange(0, len(mem))
            assert torch.all(v == x)

        choices = [alloc, free, torch.cuda.memory.empty_cache]

        while total >= 1024 * 1024 * 1024 / (4 * 10):
            (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1])

        random.setstate(state)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_nvml_get_handler(self):
        if not torch.version.hip:
            self.assertTrue(torch.cuda._get_pynvml_handler() is not None)
        else:
            self.assertTrue(torch.cuda._get_amdsmi_handler() is not None)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_temperature(self):
        self.assertTrue(0 <= torch.cuda.temperature() <= 150)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_power_draw(self):
        self.assertTrue(torch.cuda.power_draw() >= 0)

    @unittest.skipIf(TEST_PYNVML, "pynvml is not available")
    def test_clock_speed(self):
        self.assertTrue(torch.cuda.clock_rate() >= 0)

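# Module-level helpers for the checkpoint-pool tests below: filter snapshot
# segments by graph pool, and capture a callable into a CUDAGraph after one
# warm-up run on a side stream.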
SMALL_BUFFER = 2097152
LARGE_BUFFER = 20971520


def get_cudagraph_segments(pool_id):
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] == pool_id]


def get_all_cudagraph_segments():
    segments = torch.cuda.memory_snapshot()
    return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]


def cudagraphify(fn, inputs, pool=None):
    if not TEST_CUDA_GRAPH:
        raise unittest.SkipTest("cuda graph test is skipped")

    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        fn(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, pool=pool):
        static_outputs = fn(*inputs)

    return graph, static_outputs

def int8_cuda(size):
    return torch.ones([size], device="cuda", dtype=torch.uint8)


def live_blocks(pool_id):
    blocks = 0
    seg = get_cudagraph_segments(pool_id)
    for segment in get_cudagraph_segments(pool_id):
        for block in segment["blocks"]:
            blocks += block["state"] == "active_allocated"
    return blocks


def tensor_metadata(x):
    return {
        "nbytes": x.untyped_storage().nbytes(),
        "data_ptr": x.untyped_storage().data_ptr(),
        "stride": x.stride(),
        "storage_offset": x.storage_offset(),
    }


def reconstruct_from_tensor_metadata(metadata):
    s = torch._C._construct_storage_from_data_pointer(
        metadata["data_ptr"], metadata["device"], metadata["nbytes"]
    )
    t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"])
    t.set_(
        source=s,
        storage_offset=metadata["storage_offset"],
        size=metadata["size"],
        stride=metadata["stride"],
    )
    return t

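# tensor_metadata/reconstruct_from_tensor_metadata round-trip a tensor
# through its raw storage pointer; this is how checkpointed pool state can
# hand live blocks back to freshly constructed tensors.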
@unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestBlockStateAbsorption(TestCase):
    @property
    def expandable_segments(self):
        return EXPANDABLE_SEGMENTS

    def checkCheckpointedBlock(self, before_block, after_block):
        for field in ("size", "state"):
            self.assertEqual(before_block[field], after_block[field])

    def checkCheckpointedState(self, before_segments, after_segments):
        after_ptr_to_segment = {
            segment["address"]: segment for segment in after_segments
        }

        for before_segment in before_segments:
            self.assertTrue(before_segment["address"] in after_ptr_to_segment)
            after_segment = after_ptr_to_segment[before_segment["address"]]

                self.assertEqual(before_segment[field], after_segment[field])

            self.assertEqual(
                len(before_segment["blocks"]), len(after_segment["blocks"])
            )
            for before_block, after_block in zip(
                before_segment["blocks"], after_segment["blocks"]
            ):
                self.checkCheckpointedBlock(before_block, after_block)

    def setCheckpointPoolState(
        self, device, state, stale_storages_ptr, storages_deleters=None
    ):
        stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr]
        storages_deleters = (
            []
            if not storages_deleters
            else [t.untyped_storage()._cdata for t in storages_deleters]
        )
        torch._C._cuda_setCheckpointPoolState(
            device, state, stale_storages_ptr, storages_deleters
        )

    def checkFunction(self, fn, inputs, pool=None):
        graph, outputs = cudagraphify(fn, inputs, pool=pool)

        pool_id = graph.pool()
        device = outputs[0].device.index

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(device, pool_id)
        self.setCheckpointPoolState(device, state, [], [])

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

    def setUp(self):
        super().setUp()
        self.segment_length = len(get_all_cudagraph_segments())

    def tearDown(self):
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length)

        super().tearDown()

    def test_simple(self):
        def foo():
            x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8)
            x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE)
            y = int8_cuda(SMALL_SIZE) + x1
            z = int8_cuda(SMALL_SIZE)

        self.checkFunction(foo, [])

    def test_allocated_in_middle_of_segment(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5].add_(2)

        self.checkFunction(foo, [])

    def test_multiple_middle_allocations(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[8]

        self.checkFunction(foo, [])

    def test_middle_allocations_contiguous(self):
        def foo():
            small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)]
            return small_buffers[5], small_buffers[6]

        self.checkFunction(foo, [])

    def test_additional_free_following_checkpoint(self):
        def foo():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        def foo2():
            return (int8_cuda(MIN_BLOCK_SIZE),)

        graph, outputs = cudagraphify(foo, [])
        pool_id = graph.pool()

        segments_before_checkpoint = get_cudagraph_segments(pool_id)

        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())

        self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])

        self.checkCheckpointedState(
            segments_before_checkpoint, get_cudagraph_segments(pool_id)
        )

    def test_tensor_dies_after_checkpoint(self):
        def foo():
            return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE)

        graph, outputs = cudagraphify(foo, [])
        pool_id = graph.pool()
        device = outputs[0].device.index

        segments_before_checkpoint = get_cudagraph_segments(pool_id)
        state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)

        output_data_ptrs = [output.data_ptr() for output in outputs]

        del outputs

        self.setCheckpointPoolState(device, state, [], [])

        self.assertEqual(live_blocks(pool_id), 2)
        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0])
        self.assertEqual(live_blocks(pool_id), 1)
        torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1])
        self.assertEqual(live_blocks(pool_id), 0)

def test_assigning_back_deleter_fns_to_tensor(self):
4678
int8_cuda(SMALL_BUFFER) + x,
4679
int8_cuda(SMALL_BUFFER) + x,
4680
int8_cuda(LARGE_BUFFER) + x,
4683
inp = torch.tensor([1], device="cuda")
4684
graph, outputs = cudagraphify(foo, [inp])
4685
pool_id = graph.pool()
4688
device = outputs[0].device.index
4690
for i in range(len(outputs)):
4691
self.assertTrue(outputs[i].mean(dtype=torch.float) == 2)
4693
state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id)
4695
output_ptrs = [output.untyped_storage().data_ptr() for output in outputs]
4696
ten_metadata = [tensor_metadata(t) for t in outputs]
4698
self.assertEqual(live_blocks(pool_id), 3)
4702
self.assertEqual(live_blocks(pool_id), 0)
4704
reconstructed_tensors = [
4705
reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata
4708
for i in range(len(reconstructed_tensors)):
4709
self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2)
4714
for i in range(len(reconstructed_tensors)):
4715
self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3)
4717
self.setCheckpointPoolState(
4718
device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]]
4721
self.assertEqual(live_blocks(pool_id), 3)
4723
reconstructed_tensors[0] = None
4724
self.assertEqual(live_blocks(pool_id), 2)
4726
reconstructed_tensors[1] = None
4727
self.assertEqual(live_blocks(pool_id), 1)
4730
reconstructed_tensors[2] = None
4731
self.assertEqual(live_blocks(pool_id), 1)
4733
torch._C._cuda_cudaCachingAllocator_raw_delete(output_ptrs[2])
4735
self.assertEqual(live_blocks(pool_id), 0)
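    # Illustrative sketch (not part of the original suite): the metadata that
    # tensor_metadata()/reconstruct_from_tensor_metadata() shuttle around is
    # just enough to re-view an existing allocation as a tensor. The
    # reconstruction side goes through private torch._C bindings, so only the
    # capture side is sketched here.
    def _sketch_tensor_metadata(self, x):
        return {
            "nbytes": x.untyped_storage().nbytes(),
            "data_ptr": x.untyped_storage().data_ptr(),
            "size": x.shape,
            "stride": x.stride(),
            "dtype": x.dtype,
            "device": x.device,
            "storage_offset": x.storage_offset(),
        }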
    @skipIfNoTorchVision
    def test_resnet(self):
        m = torchvision.models.resnet50()
        m = m.cuda()

        inp = torch.rand([1, 3, 255, 255], device="cuda")
        self.checkFunction(m, [inp])
    def test_check_pool_live_allocations(self):
        def foo():
            return torch.ones([4], device="cuda")

        pool = torch.cuda.graph_pool_handle()
        graph, outputs = cudagraphify(foo, [], pool=pool)

        index = outputs[0].device.index

        def check(live_dps):
            return torch._C._cuda_checkPoolLiveAllocations(index, pool, live_dps)

        self.assertTrue(check({outputs[0].data_ptr()}))
        self.assertFalse(check({outputs[0].data_ptr(), 0}))
        self.assertFalse(check(set()))

        del outputs

        self.assertTrue(check(set()))
    def test_allocate_in_thread_to_pool(self):
        def foo():
            return torch.rand([4], device="cuda")

        pool = torch.cuda.graph_pool_handle()
        graph, outputs = cudagraphify(foo, [], pool=pool)
        device = outputs[0].device.index

        @contextlib.contextmanager
        def _use_cuda_memory_pool_manager(device, mem_pool):
            """
            Context manager to use cuda graph pool for new allocations. If you use this manager
            all cudagraph tensors in use should be reflected in the allocator or they will be overwritten.
            existing_graph should already have been used in a capture, and the mem_pool must already exist.
            """
            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            stream_context = torch.cuda.stream(stream)
            stream_context.__enter__()
            torch._C._cuda_beginAllocateCurrentStreamToPool(device, mem_pool)
            try:
                yield
            finally:
                torch._C._cuda_endAllocateCurrentStreamToPool(device, mem_pool)
                torch._C._cuda_releasePool(device, mem_pool)
                stream_context.__exit__(None, None, None)

        segments = get_cudagraph_segments(pool)
        self.assertEqual(len(get_cudagraph_segments(pool)), 1)

        def use_pool():
            def alloc_buffers():
                a = int8_cuda(LARGE_BUFFER)
                b = int8_cuda(LARGE_BUFFER)

            with _use_cuda_memory_pool_manager(device, pool):
                alloc_buffers()

        def no_pool():
            def alloc_buffers():
                a = int8_cuda(LARGE_BUFFER)
                b = int8_cuda(LARGE_BUFFER)

            alloc_buffers()

        graph_thread = threading.Thread(target=use_pool)
        no_graph_thread = threading.Thread(target=no_pool)
        graph_thread.start()
        no_graph_thread.start()

        graph_thread.join()
        no_graph_thread.join()

        self.assertEqual(
            len(get_cudagraph_segments(pool)), 2 if self.expandable_segments else 4
        )

        del graph

        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        self.assertEqual(len(get_cudagraph_segments(pool)), 0)
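    # Illustrative sketch (not part of the original suite): outside of
    # allocator tests, the supported way to route captures into a shared pool
    # is the public pool= argument of torch.cuda.graph, rather than the
    # private begin/endAllocateCurrentStreamToPool bindings exercised above.
    def _sketch_capture_into_shared_pool(self):
        pool = torch.cuda.graph_pool_handle()
        g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
        x = torch.zeros(4, device="cuda")
        with torch.cuda.graph(g1, pool=pool):
            y = x + 1
        # the second capture draws its memory from the same pool
        with torch.cuda.graph(g2, pool=pool):
            z = y + 1
        return g1, g2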
    def test_no_triton_on_import(self):
        """Test that Triton is not imported on first GPU use"""
        script = "import sys; import torch; torch.rand(2, device='cuda'); print('triton' in sys.modules)"

        rc = (
            subprocess.check_output(
                [sys.executable, "-c", script],
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
            .strip()
            .decode("ascii")
        )
        self.assertEqual(rc, "False", "Triton was imported when importing torch!")
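    # Illustrative sketch (not part of the original suite): the subprocess
    # pattern above generalizes to checking that any heavyweight module stays
    # a lazy import; the module name here is an arbitrary example.
    def _sketch_is_lazy_import(self, module_name="triton"):
        script = (
            "import sys; import torch; torch.rand(2, device='cuda'); "
            f"print({module_name!r} in sys.modules)"
        )
        out = subprocess.check_output([sys.executable, "-c", script])
        return out.strip().decode("ascii") == "False"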
class TestMemPool(TestCase):
    def test_mempool_id(self):
        pool1 = torch.cuda.graph_pool_handle()
        pool2 = torch.cuda.MemPool().id

        # first value of id in a user created pool is always zero
        self.assertEqual(pool1[0] == 0, pool2[0] == 0)

        # each call to torch.cuda.graph_pool_handle() or torch.cuda.MemPool()
        # increments the id
        self.assertTrue(abs(pool2[1] - pool1[1]) > 0)
    def test_mempool_with_allocator(self):
        pool = torch.cuda.MemPool()

        # MemPool doesn't have an allocator by default
        self.assertEqual(pool.allocator, None)

        from torch.utils.cpp_extension import load_inline

        dummy_allocator_source = """
        #include <torch/extension.h>
        #include <ATen/cuda/Exceptions.h>
        #include <cuda_runtime_api.h>

        extern "C" {
          C10_EXPORT int called_dummy_alloc = 0;
          C10_EXPORT int called_dummy_free = 0;

          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
            called_dummy_alloc = 123;
            void* ptr;
            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
            return ptr;
          }

          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
            called_dummy_free = 321;
            C10_CUDA_CHECK(cudaFree(ptr));
          }
        }
        """
        dummy_allocator_libname = "dummy_allocator"
        dummy_allocator = load_inline(
            name=dummy_allocator_libname,
            cpp_sources=dummy_allocator_source,
            is_python_module=False,
            keep_intermediates=False,
            verbose=True,
            with_cuda=True,
        )
        allocator = torch.cuda.memory.CUDAPluggableAllocator(
            dummy_allocator,
            "dummy_alloc",
            "dummy_free",
        )
        pool = torch.cuda.MemPool(allocator.allocator())

        # pool should point to the same allocator as the one passed into it
        self.assertEqual(allocator.allocator(), pool.allocator)

        # no allocations have happened yet, so called_dummy_alloc should be 0
        alloc_lib = ctypes.CDLL(dummy_allocator)
        called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
        self.assertEqual(called_dummy_alloc.value, 0)

        with torch.cuda.use_mem_pool(pool):
            out = torch.randn(1, device="cuda")

        # the pluggable allocator should have serviced the allocation above
        self.assertEqual(called_dummy_alloc.value, 123)
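    # Illustrative sketch (not part of the original suite): a pluggable
    # allocator can also be installed process-wide instead of per-pool, via
    # the public torch.cuda.memory.change_current_allocator(), which must run
    # before any CUDA allocation happens. so_path and the entry-point names
    # are placeholders.
    def _sketch_install_global_allocator(self, so_path):
        allocator = torch.cuda.memory.CUDAPluggableAllocator(
            so_path, "dummy_alloc", "dummy_free"
        )
        torch.cuda.memory.change_current_allocator(allocator)
        return allocator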
    def test_mempool_context(self):
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # there is no active pool if none was made active
        self.assertEqual(active_pool, None)

        pool = torch.cuda.MemPool()
        ctx = torch.cuda.MemPoolContext(pool)
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # pool was made active
        self.assertEqual(active_pool, pool)

        del ctx
        active_pool = torch.cuda.MemPoolContext.active_pool()

        # ctx was deleted, so there is no active pool again
        self.assertEqual(active_pool, None)
    def test_mempool_multithread(self):
        pool_ids = []
        active_pool_ids = []

        def create_mempool_and_make_active():
            pool = torch.cuda.MemPool()
            pool_ids.extend([pool.id])

            ctx = torch.cuda.MemPoolContext(pool)
            active_pool = torch.cuda.MemPoolContext.active_pool()
            active_pool_ids.extend([active_pool.id])
            del ctx

        num_threads = 4
        threads = [
            threading.Thread(target=create_mempool_and_make_active)
            for t in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # each thread should create a unique mempool, since
        # mempool id creation is atomic
        self.assertEqual(len(set(pool_ids)), 4)

        # each thread should have a different active mempool, since
        # the active pool is thread local
        self.assertEqual(len(set(active_pool_ids)), 4)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
    # These tests will be instantiated with instantiate_device_type_tests
    # to apply the new OptimizerInfo structure.

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_graph_optims(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        steps_warmup = 3
        steps_train = 2

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            for actually_do_graphs in (True, False):
                params = [
                    torch.randn((i + 5, i + 5), device=device) for i in range(2)
                ] + [torch.randn((), device=device)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]

                # Control (capturable=False)
                kwargs["capturable"] = False

                opt = optim_cls(params_control, **kwargs)
                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads[i][j]
                    opt.step()

                # capturable=True
                kwargs["capturable"] = True
                opt = optim_cls(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i][j]
                    opt.step()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        opt.step()

                for i in range(steps_train):
                    if actually_do_graphs:
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads[i + steps_warmup][j])
                        g.replay()
                    else:
                        # passing capturable=True to the constructor and
                        # running without graphs should still work
                        for j, p in enumerate(params_graphed):
                            p.grad = grads[i + steps_warmup][j]
                        opt.step()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)
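    # Illustrative sketch (not part of the original suite): the
    # warmup-capture-replay structure the test above builds by hand is the
    # canonical CUDA Graphs recipe, assuming a capturable optimizer whose
    # state already lives on the GPU.
    def _sketch_graphed_optimizer_step(self, opt):
        # warm up on a side stream so capture sees a steady allocator state
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            opt.step()
        torch.cuda.current_stream().wait_stream(s)
        # capture one step, then replay it with refreshed .grad buffers
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            opt.step()
        g.replay()
        return g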
    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @optims(
        [
            optim
            for optim in optim_db
            if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on
        ],
        dtypes=[torch.float32],
    )
    def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        steps_warmup = 3
        steps_train = 2

        optim_inputs = optim_info.optim_inputs_func(device=device)

        for optim_input in optim_inputs:
            kwargs = optim_input.kwargs
            kwargs["fused"] = True

            for actually_do_graphs in (
                (True, False) if optim_info.has_capturable_arg else (True,)
            ):
                params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                # GradScaler.step() updates gradients in place, so keep
                # separate copies for the control and graphed runs
                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]
                with torch.no_grad():
                    grads_control = [[g.clone() for g in gs] for gs in grads]
                    grads_graphed = [[g.clone() for g in gs] for gs in grads]

                # Gradient Scaler
                scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
                with torch.no_grad():
                    scaler_for_control._lazy_init_scale_growth_tracker(device)

                scaler_for_graphed = torch.cuda.amp.GradScaler()
                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
                with torch.no_grad():
                    scaler_for_graphed._lazy_init_scale_growth_tracker(device)

                # Control (capturable=False)
                if optim_info.has_capturable_arg:
                    kwargs["capturable"] = False
                opt = optim_cls(params_control, **kwargs)

                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads_control[i][j]
                    scaler_for_control.step(opt)
                    scaler_for_control.update()

                # capturable=True
                if optim_info.has_capturable_arg:
                    kwargs["capturable"] = True
                opt = optim_cls(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads_graphed[i][j]
                    scaler_for_graphed.step(opt)
                    scaler_for_graphed.update()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for i in range(steps_train):
                    if actually_do_graphs:
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
                        g.replay()
                    else:
                        # passing capturable=True to the constructor and
                        # running without graphs should still work
                        for j, p in enumerate(params_graphed):
                            p.grad = grads_graphed[i + steps_warmup][j]
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)
    @onlyNativeDeviceTypes
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=[torch.float32],
    )
    def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info):
        device = device.split(":")[0]
        if device not in optim_info.supports_fused_on:
            self.skipTest(
                f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
            )
        optim_inputs = optim_info.optim_inputs_func(device=device)
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            for _separate_unscale in (True, False):
                kwargs = optim_input.kwargs
                kwargs["fused"] = True
                torch.manual_seed(20)
                (
                    mod_control,
                    mod_scaling,
                    opt_control,
                    opt_scaling,
                    data,
                    loss_fn,
                    _,
                ) = _create_scaling_case(
                    optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device
                )
                optimizer_kwargs = deepcopy(kwargs)
                optimizer_kwargs["fused"] = False
                if "lr" not in kwargs:
                    # _create_scaling_case will set lr = 1.0 if optimizer kwargs
                    # do not set lr, so do the same for the control optimizer
                    optimizer_kwargs["lr"] = 1.0
                opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs)
                scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0)
                scaler_control = torch.amp.GradScaler(device, init_scale=128.0)
                tracker = TensorTracker()
                for input, target in data:
                    opt_control.zero_grad()
                    with torch.autocast(device_type=device, dtype=torch.half):
                        output_control = mod_control(input)
                        loss_control = loss_fn(output_control, target)
                    scaler_control.scale(loss_control).backward()
                    scaler_control.step(opt_control)
                    scaler_control.update()

                    opt_scaling.zero_grad()
                    with torch.autocast(device_type=device, dtype=torch.half):
                        output_scaling = mod_scaling(input)
                        loss_scaling = loss_fn(output_scaling, target)
                    scaler_scaling.scale(loss_scaling).backward()
                    if _separate_unscale:
                        scaler_scaling.unscale_(opt_scaling)
                    scaler_scaling.step(opt_scaling)
                    scaler_scaling.update()

                    tracker.add(loss_control)
                    tracker.pop_check_set(loss_scaling, self)
                    for param_control, param_scaling in zip(
                        mod_control.parameters(), mod_scaling.parameters()
                    ):
                        tracker.add(param_control.grad)
                        tracker.pop_check_set(param_scaling.grad, self)
                        tracker.add(param_control)
                        tracker.pop_check_set(param_scaling, self)

                        state_control, state_scaling = (
                            opt_control.state[param_control],
                            opt_scaling.state[param_scaling],
                        )

                        for k in state_control:
                            actual = state_scaling[k]
                            if k == "step":
                                actual = actual.squeeze()
                            tracker.add(state_control[k])
                            tracker.pop_check_set(actual, self)
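    # Illustrative sketch (not part of the original suite): the control and
    # scaling runs above both follow the canonical torch.amp step; model,
    # optimizer, scaler, data, and loss_fn are placeholders.
    def _sketch_amp_step(self, model, opt, scaler, inp, target, loss_fn):
        opt.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.half):
            loss = loss_fn(model(inp), target)
        scaler.scale(loss).backward()
        scaler.step(opt)  # unscales grads and skips the step on inf/nan
        scaler.update()  # grows or shrinks the scale for the next iteration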
    @onlyCUDA
    @parametrize("in_place_unscale", [False, True])
    @optims(
        [optim for optim in optim_db if "cuda" in optim.supports_fused_on],
        dtypes=[torch.float32],
    )
    def test_grad_scaler_with_preset_grad_scale(
        self, device, dtype, optim_info, in_place_unscale
    ):
        weight = torch.ones((5, 5), device="cuda", requires_grad=True)
        weight.grad = torch.full_like(weight, fill_value=15)
        opt = optim_info.optim_cls([weight], lr=0.1, fused=True)
        scaler = torch.amp.GradScaler(init_scale=5)

        # simulate scaling a loss
        scaler.scale(torch.ones(5))

        if in_place_unscale:
            scaler.unscale_(opt)
            # the gradient should have been unscaled in place: 15 / 5 == 3
            self.assertEqual(weight.grad, torch.full_like(weight, fill_value=3))

        # a user-preset `grad_scale` should be honored by the fused step
        opt.grad_scale = torch.Tensor([3]).cuda()
        scaler.step(opt)
        scaler.update()

        self.assertEqual(weight.grad, torch.full_like(weight, fill_value=1))
    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @parametrize("foreach, fused", [(False, False), (True, False), (False, True)])
    @optims(
        [
            optim
            for optim in optim_db
            if "foreach" in optim.supported_impls and "cuda" in optim.supports_fused_on
        ],
        dtypes=[torch.float32],
    )
    def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused):
        torch.cuda.empty_cache()

        scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
        g = torch.cuda.CUDAGraph()
        s = torch.cuda.Stream()

        weight = torch.ones((100,), device="cuda", requires_grad=True)
        opt = optim_info.optim_cls([weight], lr=0.1, foreach=foreach, fused=fused)
        static_input = torch.ones_like(weight)
        static_grad = torch.ones_like(weight)

        # warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            loss = (weight.half() * static_input).sum()
            scaler.scale(loss).backward()
        torch.cuda.current_stream().wait_stream(s)

        opt.zero_grad(set_to_none=True)

        # capture
        with torch.cuda.stream(s):
            g.capture_begin()
            loss = (weight.half() * static_input).sum()
            scaler.scale(loss).backward()
            g.capture_end()

        input_vals = [5, 20000, 5, 40000]
        # if the scale gets updated properly, these are the scale, growth
        # tracker, and grad values we expect
        expected_scales = [4, 2, 2, 1]
        expected_growth_trackers = [1, 0, 1, 0]
        expected_grad_vals = [5 * 4, float("inf"), 5 * 2, float("inf")]

        for data, scale, growth_tracker, grad_val in zip(
            input_vals, expected_scales, expected_growth_trackers, expected_grad_vals
        ):
            static_input.fill_(data)
            g.replay()
            self.assertEqual(weight.grad, torch.full_like(weight.grad, grad_val))
            scaler.step(opt)
            scaler.update()

            self.assertEqual(scaler._scale, scale)
            self.assertEqual(scaler._growth_tracker, growth_tracker)
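    # Worked note (not part of the original suite): the expectations above
    # follow from grad = static_input * scale, with init_scale=4 and fp16
    # overflowing near 65504. Input 5 gives 5 * 4 = 20 (finite, tracker -> 1);
    # input 20000 gives 20000 * 4 = 80000 -> inf, so update() halves the scale
    # to 2 and resets the tracker; 5 then gives 5 * 2 = 10; 40000 gives
    # 40000 * 2 = 80000 -> inf again, halving the scale to 1.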
class TestGDS(TestCase):
    def _get_tmp_dir_fs_type(self):
        my_path = os.path.realpath("/tmp")
        root_type = ""
        for part in psutil.disk_partitions():
            if part.mountpoint == "/":
                root_type = part.fstype
                continue
            if part.mountpoint == my_path:
                return part.fstype
        return root_type

    @unittest.skip("Disabling as USE_CUFILE=0 by default in builds")
    def test_gds_read_write_tensors(self):
        if self._get_tmp_dir_fs_type() not in ("ext4", "xfs"):
            self.skipTest("GPUDirect Storage requires ext4/xfs for local filesystem")
        src1 = torch.randn(1024, device="cuda")
        src2 = torch.randn(2, 1024, device="cuda")
        torch.cuda.gds._gds_register_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_register_buffer(src2.untyped_storage())
        dest1 = torch.empty(1024, device="cuda")
        dest2 = torch.empty(2, 1024, device="cuda")
        with TemporaryFileName() as f:
            file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR)
            file.save_storage(src1.untyped_storage(), offset=0)
            file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
            file.load_storage(dest1.untyped_storage(), offset=0)
            file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
        self.assertEqual(src1, dest1)
        self.assertEqual(src2, dest2)
        torch.cuda.gds._gds_deregister_buffer(src1.untyped_storage())
        torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage())
instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_device_type_tests(TestCudaOptims, globals())

if __name__ == "__main__":
    run_tests()