1
# Owner(s): ["module: cuda"]
7
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
8
from torch.testing._internal.common_utils import (
11
skipIfRocmVersionLessThan,
16
# NOTE: this needs to be run in a brand new process
19
print("CUDA not available, skipping tests", file=sys.stderr)
20
TestCase = NoTest # noqa: F811
23
@torch.testing._internal.common_utils.markDynamoStrictTest
24
class TestCudaPrimaryCtx(TestCase):
25
CTX_ALREADY_CREATED_ERR_MSG = (
26
"Tests defined in test_cuda_primary_ctx.py must be run in a process "
27
"where CUDA contexts are never created. Use either run_test.py or add "
28
"--subprocess to run each test in a different subprocess."
31
@skipIfRocmVersionLessThan((4, 4, 21504))
33
for device in range(torch.cuda.device_count()):
34
# Ensure context has not been created beforehand
36
torch._C._cuda_hasPrimaryContext(device),
37
TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG,
40
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
41
def test_str_repr(self):
42
x = torch.randn(1, device="cuda:1")
44
# We should have only created context on 'cuda:1'
45
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
46
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
51
# We should still have only created context on 'cuda:1'
52
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
53
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
55
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
57
x = torch.randn(1, device="cuda:1")
59
# We should have only created context on 'cuda:1'
60
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
61
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
63
y = torch.randn(1, device="cpu")
66
# We should still have only created context on 'cuda:1'
67
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
68
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
70
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
71
def test_pin_memory(self):
72
x = torch.randn(1, device="cuda:1")
74
# We should have only created context on 'cuda:1'
75
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
76
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
78
self.assertFalse(x.is_pinned())
80
# We should still have only created context on 'cuda:1'
81
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
82
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
84
x = torch.randn(3, device="cpu").pin_memory()
86
# We should still have only created context on 'cuda:1'
87
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
88
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
90
self.assertTrue(x.is_pinned())
92
# We should still have only created context on 'cuda:1'
93
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
94
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
96
x = torch.randn(3, device="cpu", pin_memory=True)
98
# We should still have only created context on 'cuda:1'
99
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
100
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
102
x = torch.zeros(3, device="cpu", pin_memory=True)
104
# We should still have only created context on 'cuda:1'
105
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
106
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
108
x = torch.empty(3, device="cpu", pin_memory=True)
110
# We should still have only created context on 'cuda:1'
111
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
112
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
116
# We should still have only created context on 'cuda:1'
117
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
118
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
121
if __name__ == "__main__":