# Owner(s): ["module: multiprocessing"]

import contextlib
import copy
import os
import sys
import time
import unittest
from sys import platform

import torch
import torch.multiprocessing as mp
import torch.utils.hooks
from torch.nn import Parameter
from torch.testing._internal.common_cuda import IS_JETSON
from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_WINDOWS,
    load_tests,
    NO_MULTIPROCESSING_SPAWN,
    run_tests,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TEST_WITH_TORCHDYNAMO,
    TEST_WITH_TSAN,
    TestCase,
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

HAS_SHM_FILES = os.path.isdir("/dev/shm")
MAX_WAITING_TIME_IN_SECONDS = 30

TEST_CUDA_IPC = (
    torch.cuda.is_available()
    and sys.platform != "darwin"
    and sys.platform != "win32"
    and not TEST_WITH_ROCM
)  # https://github.com/pytorch/pytorch/issues/90940

TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1


class SubProcess(mp.Process):
    def __init__(self, tensor):
        # Keep a reference to the (shared) tensor so the child process can modify it.
        super().__init__()
        self.tensor = tensor


def _test_cuda_ipc_deadlock_actor(queue, iterations):
    for i in range(iterations):


def _test_cuda_ipc_deadlock_learner(queue, iterations):
    net = torch.nn.LSTM(1, 1).cuda()
    for i in range(iterations):
        queue.put(copy.deepcopy(net.state_dict()))


def simple_fill(queue, event):


def simple_pool_fill(tensor):


def send_tensor(queue, event, device, dtype):
    t = torch.ones(5, 5, device=device, dtype=dtype)


def send_and_delete_tensors(queue, event, device, dtype, count, size=5):
    for i in range(count):
        t = torch.full([size], i, device=device, dtype=dtype)


def receive_and_send_sum(queue, out_queue, event, device, dtype, count, size=5):
    s = torch.full([size], 0, device=device, dtype=dtype)
    for i in range(count):


def receive_and_send(queue, out_queue, event, count):
    for i in range(count):
        t = queue.get()
        out_queue.put(t.clone())


def sum_tensors(inq, outq):
    with torch.cuda.device(1):
        tensors = inq.get()
        for tensor in tensors:
            tensor.storage().size(),


def queue_get_exception(inqueue, outqueue):
    os.close(2)  # hide expected error message
    try:
        torch.zeros(5, 5).cuda()
    except Exception as e:
        outqueue.put(e)
    else:
        outqueue.put("no exception")


# Multiply by two in a separate stream
def cuda_multiply_two(queue, ready, done):
    with torch.cuda.stream(torch.cuda.Stream()):
        cuda_event, tensor = queue.get()


def requires_grad_variable_sharing(queue, ready):
    var = queue.get()
    queue.put(var.requires_grad)


def integer_parameter_serialization(iparam):


def autograd_sharing(queue, ready, master_modified, device, is_parameter):
    var = queue.get()
    master_modified.wait()

    expected_var = torch.arange(1.0, 26, device=device).view(5, 5)
    expected_var[0, 0] = 1000
    is_ok = var.data.equal(expected_var)
    var.data[:] = torch.ones(5, 5, device=device)

    is_ok &= var.grad is None
    is_ok &= not var._backward_hooks
    if is_parameter:
        is_ok &= type(var) == Parameter
    else:
        is_ok &= type(var) == torch.Tensor
    var._grad = torch.ones(5, 5, device=device)


def mixed_type_producer(queue, event):
    float_tensor = torch.ones(2, 2).float().cuda()
    byte_tensor = torch.zeros(2, 2).byte().cuda()
    queue.put(float_tensor)
    queue.put(byte_tensor)


def simple_autograd_function(a=1):
    torch.rand(3).requires_grad_(True).mean().backward()


@contextlib.contextmanager
def fs_sharing():
    prev_strategy = mp.get_sharing_strategy()
    mp.set_sharing_strategy("file_system")
    yield
    mp.set_sharing_strategy(prev_strategy)
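

# A minimal usage sketch (an assumption based on how the tests below drive it,
# not an additional test): fs_sharing() scopes the "file_system" sharing strategy
# to a block and restores the previous strategy afterwards, e.g.
#
#     with fs_sharing():
#         q = mp.Queue()
#         q.put(torch.zeros(5, 5).share_memory_())  # backed by a /dev/shm mapping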


class leak_checker:
    def __init__(self, test_case):
        self.checked_pids = [os.getpid()]
        self.test_case = test_case

    def __enter__(self):
        self.next_fds = self._get_next_fds(10)
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
        # Check that the 10th available file-descriptor at the end of the
        # test is no more than 4 higher than the 10th available at the
        # start. This attempts to catch file descriptor leaks, but allows
        # one-off initialization that may use up a file descriptor.
        # TODO: Disabled because this check is too flaky
        # available_fds = self._get_next_fds(10)
        # self.test_case.assertLessEqual(
        #     available_fds[-1] - self.next_fds[-1], 5)
        self.test_case.assertFalse(self.has_shm_files())

    def check_pid(self, pid):
        self.checked_pids.append(pid)

    def _get_next_fds(self, n=1):
        # dup uses the lowest-numbered unused descriptor for the new descriptor
        fds = [os.dup(0) for i in range(n)]
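        # Sketch of the probing idea (standard POSIX dup behaviour, not an extra
        # check performed here): because dup returns the lowest free descriptor
        # number, duplicating fd 0 n times and recording the resulting numbers
        # snapshots roughly how many descriptors are currently open; comparing the
        # snapshot taken in __enter__ with a later one approximates how many
        # descriptors were leaked in between (the __exit__ comparison that used
        # this is currently disabled as flaky, per the TODO above).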

    def has_shm_files(self, wait=True):
        if not HAS_SHM_FILES:
            return False
        result = self._has_shm_files()
        if not result or mp.get_sharing_strategy() != "file_system" or not wait:
            return result
        total_waiting_time = 0
        while total_waiting_time <= MAX_WAITING_TIME_IN_SECONDS and result:
            time.sleep(waiting_time)
            total_waiting_time += waiting_time
            result = self._has_shm_files()
        return result

    def _has_shm_files(self):
        names = ["torch_" + str(pid) for pid in self.checked_pids]
        for filename in os.listdir("/dev/shm"):
            for name in names:
                if filename.startswith(name):
                    return True
        return False
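

# Usage sketch (mirrors how the tests below drive the checker; the Process/queue
# details are illustrative only): the checker wraps a test body and every spawned
# worker is registered via check_pid() so its /dev/shm files are attributed to it:
#
#     with leak_checker(self) as lc:
#         p = mp.Process(target=simple_fill, args=(q, e))
#         p.start()
#         lc.check_pid(p.pid)
#         ...
#         # on __exit__ the checker asserts no torch_<pid>* files remain in /dev/shm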


@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class TestMultiprocessing(TestCase):
    def tearDown(self):
        # This will keep tests isolated from each other
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()

    def _test_sharing(self, ctx=mp, device="cpu", dtype=torch.float, repeat=1):
        x = torch.zeros(5, 5).to(device, dtype)

        p = ctx.Process(target=simple_fill, args=(q, e))

        total_waiting_time = 0
        # Once the child process is done, it will set the event to notify the
        # parent accordingly.
        while total_waiting_time <= MAX_WAITING_TIME_IN_SECONDS and not is_set:
            time.sleep(waiting_time)
            total_waiting_time += waiting_time
            is_set = e.is_set()

        self.assertTrue(is_set)
        self.assertTrue(data[0].eq(4).all())
        self.assertTrue(data[1].eq(4).all())
        self.assertFalse(p.is_alive())

        p = ctx.Process(target=send_tensor, args=(q, e, device, dtype))

        self.assertEqual(t1.size(), t2.size())
        self.assertTrue(t1.eq(1).all())
        self.assertEqual(type(s1), type(s2))
        self.assertEqual(s1.data_ptr(), s2.data_ptr())
        self.assertEqual(s1.size(), s2.size())
        self.assertEqual(s1, s2)

        # We need to delete these tensors so the producer (child process)
        # can collect them properly.
        # Mark the event as done and join the process.
        self.assertFalse(p.is_alive())

        with leak_checker(self) as lc:
            for _ in range(repeat):

    def _test_preserve_sharing(self, ctx=mp, repeat=1):
        x = torch.randn(5, 5)
        data = [x.storage(), x, x[2], x[:, 1]]

        new_data = q.get(timeout=1)
        self.assertEqual(new_data, data, atol=0, rtol=0)
        storage_cdata = data[0]._cdata
        self.assertEqual(new_data[0]._cdata, storage_cdata)
        for t in new_data[1:]:
            self.assertEqual(t.storage()._cdata, storage_cdata)

        with leak_checker(self):
            for _ in range(repeat):

    def _test_pool(self, ctx=mp, repeat=1):
        lc.check_pid(proc.pid)

        buffers = [torch.zeros(2, 2) for i in range(4)]
        results = p.map(simple_pool_fill, buffers, 1)
        self.assertEqual(len(results), len(buffers))
        self.assertEqual(r, torch.ones(2, 2) * 5, atol=0, rtol=0)
        self.assertEqual(b, torch.ones(2, 2) * 4, atol=0, rtol=0)

        with leak_checker(self) as lc:
            for _ in range(repeat):

    @unittest.skipIf(
        platform == "darwin", "file descriptor strategy is not supported on macOS"
    )
    @unittest.skipIf(
        TEST_WITH_ASAN,
        "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326",
    )
    def test_fd_sharing(self):
        self._test_sharing(repeat=TEST_REPEATS)

    @unittest.skipIf(
        platform == "darwin", "file descriptor strategy is not supported on macOS"
    )
    def test_fd_preserve_sharing(self):
        self._test_preserve_sharing(repeat=TEST_REPEATS)

    @unittest.skipIf(
        platform == "darwin", "file descriptor strategy is not supported on macOS"
    )
    def test_fd_pool(self):
        self._test_pool(repeat=TEST_REPEATS)

    @unittest.skipIf(
        TEST_WITH_ASAN,
        "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326",
    )
    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "Fails to clean up temporary /dev/shm/torch_* files, see https://github.com/pytorch/pytorch/issues/91467",
    )
    def test_fs_sharing(self):
        # The test works but is very slow on macOS, see https://github.com/pytorch/pytorch/pull/93183,
        # so run it only once there. The delay is in waiting for the child process to terminate (join).
        repeat = 1 if IS_MACOS else TEST_REPEATS
        self._test_sharing(repeat=repeat)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "Fails to clean up temporary /dev/shm/torch_* files, see https://github.com/pytorch/pytorch/issues/91467",
    )
    def test_fs_preserve_sharing(self):
        self._test_preserve_sharing(repeat=TEST_REPEATS)

    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "Fails to clean up temporary /dev/shm/torch_* files, see https://github.com/pytorch/pytorch/issues/91467",
    )
    def test_fs_pool(self):
        self._test_pool(repeat=TEST_REPEATS)

    @unittest.skipIf(not HAS_SHM_FILES, "don't know how to check if shm files exist")
    @unittest.skipIf(
        TEST_WITH_TORCHDYNAMO,
        "Fails to clean up temporary /dev/shm/torch_* files, see https://github.com/pytorch/pytorch/issues/91467",
    )
        x = torch.DoubleStorage(4)
        self.assertFalse(lc.has_shm_files())
        time.sleep(0.05)  # queue serializes asynchronously
        self.assertTrue(lc.has_shm_files(wait=False))

        with fs_sharing(), leak_checker(self) as lc:
            for _ in range(TEST_REPEATS):

    def test_inherit_tensor(self):
        t = torch.zeros(5, 5)
        p = SubProcess(t.share_memory_())
        if p.exitcode is None:
            print("test_inherit_tensor: SubProcess too slow")
        else:
            self.assertEqual(t, torch.ones(5, 5) * 3, atol=0, rtol=0)

    @unittest.skipIf(IS_WINDOWS, "Test needs to use fork multiprocessing")
    def test_autograd_errors(self):
        ctx = mp.get_context("fork")
        simple_autograd_function()
        # Autograd only uses a worker thread when GPUs are involved
        if (
            torch.cuda.is_available()
            or torch.backends.mps.is_available()
            or torch.xpu.is_available()
        ):
            with self.assertRaisesRegex(RuntimeError, r"Unable to handle autograd"):
                with ctx.Pool(3) as pool:
                    pool.map(simple_autograd_function, [1, 2, 3])
        else:
            with ctx.Pool(3) as pool:
                pool.map(simple_autograd_function, [1, 2, 3])

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN, "Test needs to use spawn multiprocessing"
    )
    def test_autograd_fine_with_spawn(self):
        ctx = mp.get_context("spawn")
        simple_autograd_function()
        with ctx.Pool(3) as pool:
            pool.map(simple_autograd_function, [1, 2, 3])

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_simple(self):
        torch.cuda.FloatTensor([1])  # initialize CUDA outside of leak checker
        self._test_sharing(mp.get_context("spawn"), "cuda", torch.float)

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_memory_allocation(self):
        ctx = mp.get_context("spawn")
        p = ctx.Process(
            target=send_and_delete_tensors, args=(q, e, "cuda", torch.int, 5)
        )
        self.assertEqual(t[0], torch.full([5], 0, dtype=torch.int32))

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_ipc_deadlock(self):
        ctx = mp.get_context("spawn")
        processes = dict(
            a=ctx.Process(target=_test_cuda_ipc_deadlock_actor, args=(queue, 100)),
            l=ctx.Process(target=_test_cuda_ipc_deadlock_learner, args=(queue, 100)),
        )

        for p in processes.values():

        for p in processes.values():

        for p in processes.values():
            self.assertFalse(p.is_alive())

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_send_many(self, name=None, size=5, count=100000):
        ctx = mp.get_context("spawn")
        p1 = ctx.Process(
            target=send_and_delete_tensors,
            args=(q1, e1, "cuda", torch.long, count, size),
        )
        p2 = ctx.Process(target=receive_and_send, args=(q1, q2, e2, count))
        p3 = ctx.Process(
            target=receive_and_send_sum,
            args=(q2, q3, e3, "cuda", torch.long, count, size),
        )
        self.assertEqual(result[0], int(count * (count - 1) / 2))

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU")
    def test_cuda_small_tensors(self):
        # Check multiple small tensors which will likely use the same
        # underlying cached allocation
        ctx = mp.get_context("spawn")
        tensors += [torch.arange(i * 5.0, (i + 1) * 5).cuda(device)]

        p = ctx.Process(target=sum_tensors, args=(inq, outq))
        results.append(outq.get())

        for i, _tensor in enumerate(tensors):
            v, device, tensor_size, storage_size = results[i]
            self.assertEqual(v, torch.arange(i * 5.0, (i + 1) * 5).sum())
            self.assertEqual(device, i % 2)
            self.assertEqual(tensor_size, 5)

            # You might think this should be the case, but it's not! After
            # data from the CUDA caching allocator goes through IPC, the
            # size of the storage is the size of the *cached cudaMalloc for
            # the entire memory block* of the storage, not just the storage.
            # See Note [CUDA IPC and the caching allocator] for more info.
            #
            # self.assertEqual(storage_size, 5)
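
            # Illustrative sketch only (the numbers are hypothetical, not asserted
            # anywhere): a receiver inspecting one of these 5-element tensors might
            # see something like
            #
            #     tensor_size  == 5
            #     storage_size == 512   # the allocator's whole mapped block
            #
            # because the IPC handle maps the entire cudaMalloc'd segment backing
            # the allocation; the exact value depends on the caching allocator's
            # block rounding and is not stable across releases.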

        # Collect current process (producer) files, make sure nothing holds
        # a ref to the sent tensors.
        # We need to collect, as the CUDA MP implementation holds one shared
        # memory 'file' for performance reasons.
        torch.cuda.ipc_collect()

    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cuda_bad_call(self):
        # Initialize CUDA in the parent before forking
        t = torch.zeros(5, 5).cuda().cpu()
        p = mp.Process(target=queue_get_exception, args=(inq, outq))
        self.assertIsInstance(outq.get(), RuntimeError)

    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_wrong_cuda_fork(self):
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process


def run(rank):
    torch.cuda.set_device(rank)


if __name__ == "__main__":
    for rank in range(size):
        # it would work fine without the line below
        x = torch.rand(20, 2).cuda()
        p = Process(target=run, args=(rank,))
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize CUDA in forked subprocess.")

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_event(self):
        ctx = mp.get_context("spawn")
        p = ctx.Process(target=cuda_multiply_two, args=(queue, ready, done))
        with torch.cuda.stream(torch.cuda.Stream()):
            tensor = torch.cuda.FloatTensor([1, 1, 1, 1])
            # Use a sleep kernel to test events. Without the event, the
            # multiply happens before the add.
            event = torch.cuda.Event(interprocess=True)
            torch.cuda._sleep(20000000)  # about 30 ms
            queue.put((event, tensor))
            done.wait()  # must wait until subprocess records event
            self.assertEqual(list(tensor), [4, 4, 4, 4])

    @staticmethod
    def _test_event_multiprocess_child(event, p2c, c2p):
        c2p.put(0)  # notify parent child is ready
        p2c.get()  # wait for record in parent
        c2p.put(1)  # notify parent synchronization is done

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_event_multiprocess(self):
        event = torch.cuda.Event(enable_timing=False, interprocess=True)
        self.assertTrue(event.query())

        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=TestMultiprocessing._test_event_multiprocess_child,
            args=(event, p2c, c2p),
        )

        c2p.get()  # wait until child process is ready
        torch.cuda._sleep(50000000)  # spin for about 50 ms
        p2c.put(0)  # notify child event is recorded

        self.assertFalse(event.query())
        c2p.get()  # wait for synchronization in child
        self.assertTrue(event.query())

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU")
    def test_event_handle_multi_gpu(self):
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        with torch.cuda.device(d0):
            e0 = torch.cuda.Event(enable_timing=False, interprocess=True)

        with torch.cuda.device(d1):
            # create handle on different device from un-recorded event
            e0.ipc_handle()

        with torch.cuda.device(d0):
            e1 = torch.cuda.Event(enable_timing=False, interprocess=True)
            stream = torch.cuda.Stream()
            torch.cuda._sleep(50000000)  # spin for about 50 ms
            e1.record(stream)

        with torch.cuda.device(d1):
            # create handle on different device from recorded event
            e1.ipc_handle()

    @staticmethod
    def _test_event_handle_importer_consumer(handle, p2c, c2p):
        e1 = torch.cuda.Event.from_ipc_handle(0, handle)
        c2p.put(0)  # notify parent child is ready
        p2c.get()  # wait for record in parent
        c2p.put(1)  # notify synchronization is done in child
        p2c.get()  # wait for parent to finish before destructing child event

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_event_handle_importer(self):
        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
        self.assertTrue(e0.query())

        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=TestMultiprocessing._test_event_handle_importer_consumer,
            args=(e0.ipc_handle(), p2c, c2p),
        )

        c2p.get()  # wait for child to become ready
        torch.cuda._sleep(50000000)  # spin for about 50 ms
        p2c.put(0)  # notify child event is recorded

        self.assertFalse(e0.query())
        c2p.get()  # wait for synchronization in child
        self.assertTrue(e0.query())
        p2c.put(1)  # notify child that parent is done

    @staticmethod
    def _test_event_handle_exporter_consumer(handle, p2c, c2p):
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            e1 = torch.cuda.Event.from_ipc_handle(torch.cuda.current_device(), handle)
            torch.cuda._sleep(50000000)  # spin for about 50 ms
            # wait for parent process to finish synchronization before
            # destructing the child event

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_event_handle_exporter(self):
        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)

        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=TestMultiprocessing._test_event_handle_exporter_consumer,
            args=(e0.ipc_handle(), p2c, c2p),
        )

        # wait until the event in the child process is recorded
        self.assertFalse(e0.query())
        self.assertTrue(e0.query())

    def _test_empty_tensor_sharing(self, dtype, device):
        q = mp.Queue()
        empty = torch.tensor([], dtype=dtype, device=device)
        q.put(empty)
        out = q.get(timeout=1)
        self.assertEqual(out, empty)

    def test_empty_tensor_sharing(self):
        self._test_empty_tensor_sharing(torch.float32, torch.device("cpu"))
        self._test_empty_tensor_sharing(torch.int64, torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_empty_tensor_sharing_cuda(self):
        self._test_empty_tensor_sharing(torch.float32, torch.device("cuda"))
        self._test_empty_tensor_sharing(torch.int64, torch.device("cuda"))

    def test_empty_tensor_sharing_meta(self):
        self._test_empty_tensor_sharing(torch.float32, torch.device("meta"))
        self._test_empty_tensor_sharing(torch.int64, torch.device("meta"))

    def test_tensor_sharing_meta(self):
        dtype = torch.float32
        device = torch.device("meta")
        q = mp.Queue()
        empty = torch.tensor([1], dtype=dtype, device=device)
        q.put(empty)
        out = q.get(timeout=1)
        self.assertEqual(out, empty)

    def test_meta_simple(self):
        self._test_sharing(mp.get_context("spawn"), "meta", torch.float)

    def _test_autograd_sharing(self, var, ctx=mp, is_parameter=False):
        device = "cuda" if var.is_cuda else "cpu"

        master_modified = ctx.Event()
        p = ctx.Process(
            target=autograd_sharing,
            args=(queue, ready, master_modified, device, is_parameter),
        )

        # This would cause an error if we tried to serialize the hooks,
        # because it's a closure and pickle doesn't support closures.
        @torch.utils.hooks.unserializable_hook
        def hook(*unused):
            pass

        if var.requires_grad:
            var.register_hook(hook)
        var._grad = torch.zeros(5, 5, device=device)

        var.data[0, 0] = 1000
        var.grad.data[:] = torch.ones(5, 5, device=device) * 4
        master_modified.set()

        worker_ok = queue.get()
        self.assertTrue(worker_ok)

        self.assertEqual(var.data, torch.ones(5, 5, device=device))
        self.assertEqual(var.grad.data, torch.ones(5, 5, device=device) * 4)
        self.assertFalse(p.is_alive())

    # Check sharing a cudaMalloc allocation with different types of storage.
    def _test_mixed_types_cuda_sharing(self, ctx=mp):
        all_ones = torch.ones(2, 2).float()
        all_zeros = torch.zeros(2, 2).byte()

        p = ctx.Process(target=mixed_type_producer, args=(queue, event))

        float_tensor = queue.get()
        byte_tensor = queue.get()
        self.assertEqual(float_tensor, all_ones)
        self.assertEqual(byte_tensor, all_zeros)
        del float_tensor, byte_tensor

    @unittest.skipIf(
        TEST_WITH_ASAN,
        "non-deterministically hangs with ASAN https://github.com/pytorch/pytorch/issues/94024",
    )
    def test_variable_sharing(self):
        for requires_grad in [True, False]:
            var = torch.arange(1.0, 26).view(5, 5).requires_grad_(requires_grad)
            self._test_autograd_sharing(var)

    # See https://github.com/pytorch/pytorch/issues/14997
    @unittest.skipIf(TEST_WITH_ASAN, "non-deterministically hangs with ASAN")
    def test_leaf_variable_sharing(self):
        devices = ["cpu"]
        if torch.cuda.is_available() and not NO_MULTIPROCESSING_SPAWN and TEST_CUDA_IPC:
            devices.append("cuda")
        for device in devices:
            for requires_grad in [True, False]:
                var = (
                    torch.arange(1.0, 26, device=device)
                    .view(5, 5)
                    .requires_grad_(requires_grad)
                )
                self.assertTrue(var.is_leaf)
                ctx = mp.get_context("spawn") if device == "cuda" else mp
                p = ctx.Process(
                    target=requires_grad_variable_sharing, args=(queue, ready)
                )
                worker_requires_grad = queue.get()
                self.assertTrue(worker_requires_grad == requires_grad)

    def test_non_leaf_variable_sharing(self):
        devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
        for device in devices:
            var0 = torch.arange(1.0, 26, device=device).view(5, 5).requires_grad_(True)
            # Don't use a regular Queue; it uses a background thread (which
            # means we can't catch the exceptions)
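            # (Reasoning sketch, assuming standard multiprocessing behaviour:
            # mp.Queue pickles objects on a background feeder thread, so a
            # serialization failure there surfaces asynchronously instead of in the
            # caller, while SimpleQueue pickles in the calling thread, which lets
            # assertRaisesRegex below observe the RuntimeError directly.)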
            queue = mp.SimpleQueue()
            self.assertRaisesRegex(
                RuntimeError, r"requires_grad", lambda: queue.put(var)
            )

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_variable_sharing(self):
        for requires_grad in [True, False]:
            var = (
                torch.arange(1.0, 26, device="cuda")
                .view(5, 5)
                .requires_grad_(requires_grad)
            )
            self._test_autograd_sharing(var, mp.get_context("spawn"))

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_mixed_types_cuda_sharing(self):
        self._test_mixed_types_cuda_sharing(mp.get_context("spawn"))

    def test_parameter_sharing(self):
        param = Parameter(torch.arange(1.0, 26).view(5, 5))
        self._test_autograd_sharing(param, is_parameter=True)

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_parameter_sharing(self):
        param = Parameter(torch.arange(1.0, 26, device="cuda").view(5, 5))
        self._test_autograd_sharing(param, mp.get_context("spawn"), is_parameter=True)

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    def test_integer_parameter_serialization_cpu(self):
        self._test_integer_parameter_serialization(device="cpu")

    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that "
        "don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_integer_parameter_serialization_cuda(self):
        self._test_integer_parameter_serialization(device="cuda")

    def _test_integer_parameter_serialization(self, device):
        param = torch.nn.Parameter(
            torch.tensor(0, dtype=torch.int64, device=device), requires_grad=False
        )

        ctx = mp.get_context("spawn")
        p = ctx.Process(target=integer_parameter_serialization, args=(param,))

            msg=f'Failed to serialize successfully for "{device}" device!',

    def test_empty_shared(self):
        t = torch.tensor([])
        t.share_memory_()

    def _test_is_shared(self):
        t = torch.randn(5, 5)
        self.assertFalse(t.is_shared())
        t.share_memory_()
        self.assertTrue(t.is_shared())

    @unittest.skipIf(
        platform == "darwin", "file descriptor strategy is not supported on macOS"
    )
    def test_is_shared(self):
        self._test_is_shared()

    def test_fs_is_shared(self):
        with fs_sharing():
            self._test_is_shared()

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_is_shared_cuda(self):
        t = torch.randn(5, 5).cuda()
        self.assertTrue(t.is_shared())

    @unittest.skipIf(sys.platform != "linux", "Only runs on Linux; requires prctl(2)")
    def test_set_thread_name(self):
        mp._set_thread_name(name)
        self.assertEqual(mp._get_thread_name(), name)


if __name__ == "__main__":
    run_tests()