pytorch

Форк
0
/
test_multiprocessing.py 
1092 строки · 35.5 Кб
1
# Owner(s): ["module: multiprocessing"]
2

3
import contextlib
4
import copy
5
import gc
6
import os
7
import sys
8
import time
9
import unittest
10
from sys import platform
11

12
import torch
13
import torch.cuda
14
import torch.multiprocessing as mp
15
import torch.utils.hooks
16
from torch.nn import Parameter
17
from torch.testing._internal.common_cuda import IS_JETSON
18
from torch.testing._internal.common_utils import (
19
    IS_MACOS,
20
    IS_WINDOWS,
21
    load_tests,
22
    NO_MULTIPROCESSING_SPAWN,
23
    run_tests,
24
    slowTest,
25
    TEST_WITH_ASAN,
26
    TEST_WITH_ROCM,
27
    TEST_WITH_TORCHDYNAMO,
28
    TEST_WITH_TSAN,
29
    TestCase,
30
)
31

32

33
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# How many times the repeated sharing/pool tests run their inner round-trip.
TEST_REPEATS = 30
# Whether /dev/shm exists; leak_checker scans it for "torch_<pid>*" files.
HAS_SHM_FILES = os.path.isdir("/dev/shm")
# Upper bound (seconds) for the polling loops that wait on a child's event or
# on shm-file cleanup.
MAX_WAITING_TIME_IN_SECONDS = 30

# CUDA IPC tests need CUDA and are disabled on macOS, Windows, Jetson and ROCm.
TEST_CUDA_IPC = (
    torch.cuda.is_available()
    and sys.platform != "darwin"
    and sys.platform != "win32"
    and not IS_JETSON
    and not TEST_WITH_ROCM
)  # https://github.com/pytorch/pytorch/issues/90940

# Multi-GPU IPC tests additionally require more than one visible CUDA device.
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
50

51

52
class SubProcess(mp.Process):
    """Daemon child process that adds 3, in place, to the tensor it was given."""

    def __init__(self, tensor):
        super().__init__(daemon=True)
        self.tensor = tensor

    def run(self):
        # In-place add: a shared-memory tensor is visible to the parent.
        self.tensor += 3
60

61

62
def _test_cuda_ipc_deadlock_actor(queue, iterations):
63
    for i in range(iterations):
64
        if not queue.empty():
65
            queue.get()
66
        time.sleep(0.01)
67

68

69
def _test_cuda_ipc_deadlock_learner(queue, iterations):
    """Producer half of the CUDA IPC deadlock test.

    Builds a small CUDA LSTM. Then, for ``iterations`` rounds, it pushes a
    deep copy of its ``state_dict()`` whenever ``queue`` has room, and sleeps
    10 ms after each round.
    """
    model = torch.nn.LSTM(1, 1).cuda()
    for _step in range(iterations):
        has_room = not queue.full()
        if has_room:
            # Put the copy directly; no extra local reference to the CUDA
            # tensors is kept in this process.
            queue.put(copy.deepcopy(model.state_dict()))
        time.sleep(0.01)
75

76

77
def simple_fill(queue, event):
    """Receive a list from ``queue``, fill its first tensor with 4, set ``event``."""
    payload = queue.get()
    target = payload[0]
    target[:] = 4
    event.set()
81

82

83
def simple_pool_fill(tensor):
    """Fill ``tensor`` with 4 in place and return a new tensor equal to it plus 1."""
    tensor.fill_(4)
    return tensor + 1
86

87

88
def send_tensor(queue, event, device, dtype):
    """Put the same 5x5 ones tensor on ``queue`` twice, then wait on ``event``.

    The parent uses the two copies to check that they share one storage.
    """
    ones = torch.ones(5, 5, device=device, dtype=dtype)
    for _ in range(2):
        queue.put(ones)
    event.wait()
93

94

95
def send_and_delete_tensors(queue, event, device, dtype, count, size=5):
    """Send ``count`` tensors ``full([size], i)`` for i = 0..count-1, then wait.

    No local reference to a sent tensor is kept after its ``put``.
    """
    for value in range(count):
        queue.put(torch.full([size], value, device=device, dtype=dtype))
    event.wait()
101

102

103
def receive_and_send_sum(queue, out_queue, event, device, dtype, count, size=5):
    """Sum ``count`` tensors from ``queue`` into a fresh tensor and send the result.

    The sum is sent on ``out_queue``; the function then waits on ``event``.
    """
    total = torch.full([size], 0, device=device, dtype=dtype)
    for _ in range(count):
        total += queue.get()
    out_queue.put(total)
    event.wait()
110

111

112
def receive_and_send(queue, out_queue, event, count):
    """Forward ``count`` tensors from ``queue`` to ``out_queue`` as clones, then wait."""
    for _ in range(count):
        out_queue.put(queue.get().clone())
    event.wait()
117

118

119
def sum_tensors(inq, outq):
    """With cuda:1 current, report facts about each tensor received from ``inq``.

    For every tensor, a tuple ``(sum, device index, numel, storage size)`` is
    put on ``outq``.
    """
    with torch.cuda.device(1):
        received = inq.get()
        for tensor in received:
            summary = (
                tensor.sum().item(),
                tensor.get_device(),
                tensor.numel(),
                tensor.storage().size(),
            )
            outq.put(summary)
131

132

133
def queue_get_exception(inqueue, outqueue):
    """Try to use CUDA in this process and report the outcome on ``outqueue``.

    Puts the raised exception, or the string ``"no exception"``. ``inqueue``
    is not read.
    """
    os.close(2)  # hide expected error message
    try:
        torch.zeros(5, 5).cuda()
    except Exception as err:
        outcome = err
    else:
        outcome = "no exception"
    outqueue.put(outcome)
141

142

143
# Multiply by two in a separate stream
def cuda_multiply_two(queue, ready, done):
    """Child side of ``test_event``: double a received CUDA tensor.

    Signals ``ready``, then receives ``(cuda_event, tensor)`` from ``queue``.
    On a fresh CUDA stream it waits on the event, multiplies the tensor by two
    in place, records the event again and signals ``done``.
    """
    ready.set()
    with torch.cuda.stream(torch.cuda.Stream()):
        cuda_event, tensor = queue.get()
        # Make this stream wait for the work the sender recorded on the event.
        cuda_event.wait()
        tensor.mul_(2)
        # Re-record so that synchronizing on the event also covers the multiply.
        cuda_event.record()
        done.set()
        del cuda_event
153

154

155
def requires_grad_variable_sharing(queue, ready):
    """Receive a tensor, signal ``ready``, and echo its ``requires_grad`` flag back."""
    received = queue.get()
    ready.set()
    flag = received.requires_grad
    queue.put(flag)
159

160

161
def integer_parameter_serialization(iparam):
    """Use a deserialized value in arithmetic to prove it arrived intact.

    The sum is discarded; only an exception from ``+`` would be observable.
    """
    _unused = iparam + 1
163

164

165
def autograd_sharing(queue, ready, master_modified, device, is_parameter):
    """Child side of ``_test_autograd_sharing``.

    Receives a tensor, signals ``ready``, and waits for ``master_modified``.
    It then checks four things:

    * the data equals ``arange(1, 26).view(5, 5)`` with ``[0, 0] == 1000``;
    * there is no grad;
    * there are no backward hooks;
    * the exact type is ``Parameter`` or ``torch.Tensor``, as ``is_parameter``
      requests.

    Along the way it overwrites the data with ones and installs a ones grad.
    Finally it puts the combined boolean verdict back on ``queue``.
    """
    var = queue.get()
    ready.set()
    master_modified.wait()

    expected = torch.arange(1.0, 26, device=device).view(5, 5)
    expected[0, 0] = 1000
    # The data must be compared before it is overwritten below.
    data_matches = var.data.equal(expected)
    var.data[:] = torch.ones(5, 5, device=device)

    wanted_type = Parameter if is_parameter else torch.Tensor
    checks = (
        data_matches,
        var.grad is None,
        not var._backward_hooks,
        type(var) == wanted_type,
    )
    verdict = all(checks)
    var._grad = torch.ones(5, 5, device=device)

    queue.put(verdict)
184

185

186
def mixed_type_producer(queue, event):
    """Ten times, send a float32 ones tensor and a uint8 zeros tensor on CUDA.

    After each pair it waits for ``event`` and then clears it for the next
    round.
    """
    for _ in range(10):
        ones_f32 = torch.ones(2, 2).float().cuda()
        zeros_u8 = torch.zeros(2, 2).byte().cuda()

        for item in (ones_f32, zeros_u8):
            queue.put(item)
        event.wait()
        event.clear()
195

196

197
def simple_autograd_function(a=1):
    """Run a tiny backward pass (to engage autograd), then return ``a ** 2``."""
    leaf = torch.rand(3).requires_grad_(True)
    leaf.mean().backward()
    return a**2
200

201

202
@contextlib.contextmanager
def fs_sharing():
    """Temporarily switch the sharing strategy to ``"file_system"``.

    The previous strategy is restored on exit, including when the body raises.
    """
    saved_strategy = mp.get_sharing_strategy()
    mp.set_sharing_strategy("file_system")
    with contextlib.ExitStack() as restore:
        restore.callback(mp.set_sharing_strategy, saved_strategy)
        yield
210

211

212
class leak_checker:
    """Context manager that fails a test which leaves torch shm files behind.

    On a clean exit it asserts, via ``test_case``, that ``/dev/shm`` contains
    no file whose name starts with ``torch_<pid>``. The PIDs checked are this
    process plus any registered with :meth:`check_pid`. Under the
    ``file_system`` sharing strategy it polls for up to
    ``MAX_WAITING_TIME_IN_SECONDS`` so that asynchronous cleanup can finish.
    """

    def __init__(self, test_case):
        self.test_case = test_case
        self.checked_pids = [os.getpid()]

    def __enter__(self):
        # Snapshot of the next free descriptors. It is only consumed by the
        # (currently disabled) fd-leak check in __exit__.
        self.next_fds = self._get_next_fds(10)
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
        exc_type = args[0]
        if exc_type is None:
            # TODO: an fd-leak check used to run here. It required the 10th
            # free descriptor to be at most a few numbers above the one
            # captured in __enter__, to tolerate one-off initialisation. It
            # is disabled because it was too flaky.
            self.test_case.assertFalse(self.has_shm_files())
        # Never swallow an exception raised inside the ``with`` body.
        return False

    def check_pid(self, pid):
        self.checked_pids.append(pid)

    def _get_next_fds(self, n=1):
        # os.dup always returns the lowest-numbered free descriptor. Holding n
        # duplicates at once therefore reveals the next n free fd numbers.
        probes = [os.dup(0) for _ in range(n)]
        for probe in probes:
            os.close(probe)
        return probes

    def has_shm_files(self, wait=True):
        if not HAS_SHM_FILES:
            return False

        found = self._has_shm_files()
        must_poll = found and wait and mp.get_sharing_strategy() == "file_system"
        if not must_poll:
            return found

        waited = 0
        interval = 0.5
        while found and waited <= MAX_WAITING_TIME_IN_SECONDS:
            time.sleep(interval)
            waited += interval
            found = self._has_shm_files()
        return found

    def _has_shm_files(self):
        # Collect first so that unreferenced shared storages release their files.
        gc.collect()
        prefixes = tuple("torch_" + str(pid) for pid in self.checked_pids)
        return any(entry.startswith(prefixes) for entry in os.listdir("/dev/shm"))
272

273

274
@unittest.skipIf(
275
    TEST_WITH_TSAN,
276
    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
277
)
278
class TestMultiprocessing(TestCase):
279
    def tearDown(self):
        """Release CUDA IPC memory after every test so tests stay isolated."""
        if not torch.cuda.is_available():
            return
        torch.cuda.ipc_collect()
283

284
    def _test_sharing(self, ctx=mp, device="cpu", dtype=torch.float, repeat=1):
        """Exercise tensor sharing through ``ctx`` queues under a leak checker.

        Each repetition runs two sub-tests:

        * ``test_fill`` puts ``[x, x[:, 1]]`` on a queue. A child
          (``simple_fill``) writes 4 into ``x`` and sets an event, and the
          parent checks that the write is visible through both entries.
        * ``test_receive`` has a child (``send_tensor``) put the same tensor
          on the queue twice. It checks that the two received tensors are
          backed by one storage (same storage type and ``data_ptr``).

        Every child PID is registered with the leak checker after the child
        has started.

        Args:
            ctx: multiprocessing context (or module) used to create queues,
                events and processes.
            device: device of the shared tensors. For ``"meta"`` only sizes
                are compared, because meta tensors carry no data.
            dtype: dtype of the shared tensors.
            repeat: number of times both sub-tests are run.
        """

        def test_fill():
            x = torch.zeros(5, 5).to(device, dtype)
            q = ctx.Queue()
            e = ctx.Event()

            data = [x, x[:, 1]]
            q.put(data)

            p = ctx.Process(target=simple_fill, args=(q, e))
            p.daemon = True
            p.start()
            # p.pid is only assigned by start(). Registering it earlier would
            # record None and leave the child's shm files unchecked.
            lc.check_pid(p.pid)

            total_waiting_time = 0
            waiting_time = 0.5
            is_set = False
            # Once the child process is done, it will set the event to notify the
            # parent accordingly
            while total_waiting_time <= MAX_WAITING_TIME_IN_SECONDS and not is_set:
                time.sleep(waiting_time)
                total_waiting_time += waiting_time
                is_set = e.is_set()

            self.assertTrue(is_set)
            if device != "meta":
                self.assertTrue(data[0].eq(4).all())
                self.assertTrue(data[1].eq(4).all())

            p.join(100)
            self.assertFalse(p.is_alive())

        def test_receive():
            q = ctx.Queue()
            e = ctx.Event()

            p = ctx.Process(target=send_tensor, args=(q, e, device, dtype))
            p.daemon = True
            p.start()
            # Register only after start(), when p.pid is a real PID.
            lc.check_pid(p.pid)

            t1 = q.get()
            t2 = q.get()
            if device == "meta":
                self.assertEqual(t1.size(), t2.size())
            else:
                self.assertTrue(t1.eq(1).all())
            s1 = t1.storage()
            s2 = t2.storage()
            self.assertEqual(type(s1), type(s2))
            # Both tensors came from the same child tensor, so they must alias
            # the same memory.
            self.assertEqual(s1.data_ptr(), s2.data_ptr())
            if device == "meta":
                self.assertEqual(s1.size(), s2.size())
            else:
                self.assertEqual(s1, s2)

            # We need to delete this tensors to allow producer (child process)
            # collect them properly
            del t1, t2

            # Mark the event as done and join the process
            e.set()
            p.join(100)
            self.assertFalse(p.is_alive())

        with leak_checker(self) as lc:
            for _ in range(repeat):
                test_fill()
                test_receive()
353

354
    def _test_preserve_sharing(self, ctx=mp, repeat=1):
        """Check that aliasing survives a queue round-trip within one process.

        The payload is a storage plus three tensors viewing it. After the
        round-trip, the received storage and every received tensor's storage
        must have the same ``_cdata`` as the original storage.
        """

        def round_trip_once():
            source = torch.randn(5, 5)
            payload = [source.storage(), source, source[2], source[:, 1]]
            channel = ctx.Queue()
            channel.put(payload)
            received = channel.get(timeout=1)
            self.assertEqual(received, payload, atol=0, rtol=0)
            shared_cdata = payload[0]._cdata
            received_storage, *received_tensors = received
            self.assertEqual(received_storage._cdata, shared_cdata)
            for tensor in received_tensors:
                self.assertEqual(tensor.storage()._cdata, shared_cdata)

        with leak_checker(self):
            for _ in range(repeat):
                round_trip_once()
370

371
    def _test_pool(self, ctx=mp, repeat=1):
        """Fill CPU tensors from a 2-worker ``ctx.Pool`` under a leak checker.

        Each worker runs ``simple_pool_fill``, which fills its argument with 4
        in place and returns ``tensor + 1``. The parent checks the returned
        values (all 5) and that the in-place fill (all 4) is visible in its
        own buffers. The pool is always closed and joined, even when ``map``
        or an assertion fails, so its workers do not outlive the test.
        """

        def do_test():
            p = ctx.Pool(2)
            try:
                for proc in p._pool:
                    lc.check_pid(proc.pid)

                buffers = [torch.zeros(2, 2) for _ in range(4)]
                results = p.map(simple_pool_fill, buffers, 1)
                self.assertEqual(len(results), len(buffers))
                for r in results:
                    self.assertEqual(r, torch.ones(2, 2) * 5, atol=0, rtol=0)
                for b in buffers:
                    self.assertEqual(b, torch.ones(2, 2) * 4, atol=0, rtol=0)
            finally:
                p.close()
                p.join()

        with leak_checker(self) as lc:
            for _ in range(repeat):
                do_test()
391

392
    @unittest.skipIf(
393
        platform == "darwin", "file descriptor strategy is not supported on macOS"
394
    )
395
    @unittest.skipIf(
396
        TEST_WITH_ASAN,
397
        "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326",
398
    )
399
    def test_fd_sharing(self):
        """Sharing round-trip with the current (unchanged) strategy, repeated."""
        self._test_sharing(ctx=mp, repeat=TEST_REPEATS)
401

402
    @unittest.skipIf(
403
        platform == "darwin", "file descriptor strategy is not supported on macOS"
404
    )
405
    def test_fd_preserve_sharing(self):
        """Storage aliasing round-trip with the current strategy, repeated."""
        self._test_preserve_sharing(ctx=mp, repeat=TEST_REPEATS)
407

408
    @unittest.skipIf(
409
        platform == "darwin", "file descriptor strategy is not supported on macOS"
410
    )
411
    def test_fd_pool(self):
        """Pool fill round-trip with the current strategy, repeated."""
        self._test_pool(ctx=mp, repeat=TEST_REPEATS)
413

414
    @unittest.skipIf(
415
        TEST_WITH_ASAN,
416
        "seems to hang with ASAN, see https://github.com/pytorch/pytorch/issues/5326",
417
    )
418
    @unittest.skipIf(
419
        TEST_WITH_TORCHDYNAMO,
420
        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
421
    )
422
    def test_fs_sharing(self):
        """Sharing round-trip under the ``file_system`` strategy."""
        # The test works but is very slow on MacOS, see https://github.com/pytorch/pytorch/pull/93183,
        # so run it only once there. The delay is in waiting for the child process to terminate (join)
        repeat = 1 if IS_MACOS else TEST_REPEATS
        with fs_sharing():
            self._test_sharing(repeat=repeat)
428

429
    @unittest.skipIf(
430
        TEST_WITH_TORCHDYNAMO,
431
        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
432
    )
433
    def test_fs_preserve_sharing(self):
        """Storage aliasing round-trip under the ``file_system`` strategy."""
        with fs_sharing():
            self._test_preserve_sharing(ctx=mp, repeat=TEST_REPEATS)
436

437
    @unittest.skipIf(
438
        TEST_WITH_TORCHDYNAMO,
439
        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
440
    )
441
    def test_fs_pool(self):
        """Pool fill round-trip under the ``file_system`` strategy."""
        with fs_sharing():
            self._test_pool(ctx=mp, repeat=TEST_REPEATS)
444

445
    @unittest.skipIf(not HAS_SHM_FILES, "don't not how to check if shm files exist")
446
    @unittest.skipIf(
447
        TEST_WITH_TORCHDYNAMO,
448
        "Fail to clean up temporary /dev/shm/torch_* file, see https://github.com/pytorch/pytorch/issues/91467",
449
    )
450
    def test_fs(self):
        """Queuing a storage under ``file_system`` creates a shm file, later cleaned up."""

        def put_and_check():
            storage = torch.DoubleStorage(4)
            channel = mp.Queue()
            self.assertFalse(checker.has_shm_files())
            channel.put(storage)
            time.sleep(0.05)  # queue serializes asynchronously
            self.assertTrue(checker.has_shm_files(wait=False))
            channel.get()

        with fs_sharing(), leak_checker(self) as checker:
            for _ in range(TEST_REPEATS):
                put_and_check()
463

464
    def test_inherit_tensor(self):
        """A child process sees and mutates a shared-memory tensor it inherited."""
        shared = torch.zeros(5, 5)
        child = SubProcess(shared.share_memory_())
        child.start()
        child.join(2)
        if child.exitcode is not None:
            self.assertEqual(shared, torch.ones(5, 5) * 3, atol=0, rtol=0)
        else:
            print("test_inherit_tensor: SubProcess too slow")
473

474
    @unittest.skipIf(IS_WINDOWS, "Test needs to use fork multiprocessing")
475
    def test_autograd_errors(self):
        """Using autograd in forked workers after the parent did fails only with a GPU backend."""
        ctx = mp.get_context("fork")
        simple_autograd_function()
        # Autograd only uses thread when GPUs are involved
        has_accelerator = (
            torch.cuda.is_available()
            or torch.backends.mps.is_available()
            or torch.xpu.is_available()
        )
        if has_accelerator:
            expectation = self.assertRaisesRegex(
                RuntimeError, r"Unable to handle autograd"
            )
        else:
            expectation = contextlib.nullcontext()
        with expectation:
            with ctx.Pool(3) as pool:
                pool.map(simple_autograd_function, [1, 2, 3])
490

491
    @unittest.skipIf(
492
        NO_MULTIPROCESSING_SPAWN, "Test needs to use spawn multiprocessing"
493
    )
494
    def test_autograd_fine_with_spawn(self):
        """Spawned workers may use autograd even after the parent already did."""
        spawn_ctx = mp.get_context("spawn")
        simple_autograd_function()
        with spawn_ctx.Pool(3) as workers:
            workers.map(simple_autograd_function, [1, 2, 3])
499

500
    @unittest.skipIf(
501
        NO_MULTIPROCESSING_SPAWN,
502
        "Disabled for environments that \
503
                     don't support multiprocessing with spawn start method",
504
    )
505
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
506
    def test_cuda_simple(self):
        """Sharing round-trip with CUDA float tensors over a spawn context."""
        torch.cuda.FloatTensor([1])  # initialize CUDA outside of leak checker
        spawn_ctx = mp.get_context("spawn")
        self._test_sharing(ctx=spawn_ctx, device="cuda", dtype=torch.float)
509

510
    @unittest.skipIf(
511
        NO_MULTIPROCESSING_SPAWN,
512
        "Disabled for environments that \
513
                     don't support multiprocessing with spawn start method",
514
    )
515
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
516
    def test_cuda_memory_allocation(self):
        """Receive CUDA tensors whose producer dropped its references after sending."""
        ctx = mp.get_context("spawn")
        q = ctx.Queue()
        e = ctx.Event()
        p = ctx.Process(
            target=send_and_delete_tensors, args=(q, e, "cuda", torch.int, 5)
        )
        p.start()
        received = [q.get() for _ in range(5)]
        self.assertEqual(received[0], torch.full([5], 0, dtype=torch.int32))
        del received
        e.set()
        p.join(1)
531

532
    @unittest.skipIf(
533
        NO_MULTIPROCESSING_SPAWN,
534
        "Disabled for environments that \
535
                     don't support multiprocessing with spawn start method",
536
    )
537
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
538
    def test_cuda_ipc_deadlock(self):
        """Actor and learner exchanging CUDA state dicts via a size-1 queue must finish."""
        ctx = mp.get_context("spawn")
        queue = ctx.Queue(1)
        workers = [
            ctx.Process(target=_test_cuda_ipc_deadlock_actor, args=(queue, 100)),
            ctx.Process(target=_test_cuda_ipc_deadlock_learner, args=(queue, 100)),
        ]

        for worker in workers:
            worker.start()

        for worker in workers:
            worker.join(10)

        for worker in workers:
            self.assertFalse(worker.is_alive())
554

555
    @slowTest
556
    @unittest.skipIf(
557
        NO_MULTIPROCESSING_SPAWN,
558
        "Disabled for environments that \
559
                     don't support multiprocessing with spawn start method",
560
    )
561
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
562
    def test_cuda_send_many(self, name=None, size=5, count=100000):
        """Pipe ``count`` CUDA tensors through a chain of three processes.

        p1 sends ``full([size], i)`` for i in ``range(count)``. p2 forwards a
        clone of each tensor. p3 sums them all and sends the total back, so
        every element of the result must equal ``0 + 1 + ... + (count - 1)``.
        ``name`` is not used.
        """
        ctx = mp.get_context("spawn")
        q1 = ctx.Queue()
        q2 = ctx.Queue()
        q3 = ctx.Queue()
        e1 = ctx.Event()
        e2 = ctx.Event()
        e3 = ctx.Event()
        p1 = ctx.Process(
            target=send_and_delete_tensors,
            args=(q1, e1, "cuda", torch.long, count, size),
        )
        p2 = ctx.Process(target=receive_and_send, args=(q1, q2, e2, count))
        p3 = ctx.Process(
            target=receive_and_send_sum,
            args=(q2, q3, e3, "cuda", torch.long, count, size),
        )
        p1.start()
        p2.start()
        p3.start()
        result = q3.get()
        # count * (count - 1) is always even, so integer division is exact.
        # This avoids a float round-trip that loses precision for large counts.
        self.assertEqual(result[0], count * (count - 1) // 2)
        del result
        e1.set()
        e2.set()
        e3.set()
        p1.join(1)
        p2.join(1)
        p3.join(1)
591

592
    @unittest.skipIf(
593
        NO_MULTIPROCESSING_SPAWN,
594
        "Disabled for environments that \
595
                     don't support multiprocessing with spawn start method",
596
    )
597
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
598
    @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU")
599
    def test_cuda_small_tensors(self):
        """Send five small tensors, alternating cuda:0 / cuda:1, to one child.

        For each tensor the child reports its sum, device and numel, and these
        are checked here.
        """
        # Check multiple small tensors which will likely use the same
        # underlying cached allocation
        ctx = mp.get_context("spawn")
        tensors = [torch.arange(i * 5.0, (i + 1) * 5).cuda(i % 2) for i in range(5)]

        inq = ctx.Queue()
        outq = ctx.Queue()
        inq.put(tensors)
        p = ctx.Process(target=sum_tensors, args=(inq, outq))
        p.start()

        results = [outq.get() for _ in range(5)]
        p.join()

        for i, (v, device, tensor_size, _storage_size) in enumerate(results):
            self.assertEqual(v, torch.arange(i * 5.0, (i + 1) * 5).sum())
            self.assertEqual(device, i % 2)
            self.assertEqual(tensor_size, 5)

            # You might think this should be the case, but it's not!  After
            # data from the CUDA caching allocator goes through IPC, the
            # size of the storage is the size of the *cached cudaMalloc for
            # the entire memory block* of the storage, not just the storage.
            # See Note [CUDA IPC and the caching allocator] for more info
            #
            # self.assertEqual(storage_size, 5)

        # Collect current process (producer) files, make sure nothing holds
        # ref to the sent tensors
        del tensors

        # We need to collect, as CUDA MP implementation holds one shared
        # memory 'file' for performance reason
        torch.cuda.ipc_collect()
641

642
    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
643
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
644
    def test_cuda_bad_call(self):
        """Check the child gets a RuntimeError when it touches CUDA.

        The parent has already initialized CUDA before starting the child
        (``queue_get_exception``).
        """
        # Initialize CUDA
        t = torch.zeros(5, 5).cuda().cpu()
        inq = mp.Queue()
        outq = mp.Queue()
        p = mp.Process(target=queue_get_exception, args=(inq, outq))
        p.start()
        inq.put(t)
        # Drain the child's result before joining it. A process with undelivered
        # queue data may not terminate until that data is consumed, so
        # join-then-get can deadlock.
        result = outq.get()
        p.join()
        self.assertIsInstance(result, RuntimeError)
654

655
    @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)")
656
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
657
    def test_wrong_cuda_fork(self):
658
        stderr = TestCase.runWithPytorchAPIUsageStderr(
659
            """\
660
import torch
661
from torch.multiprocessing import Process
662
def run(rank):
663
    torch.cuda.set_device(rank)
664
if __name__ == "__main__":
665
    size = 2
666
    processes = []
667
    for rank in range(size):
668
        # it would work fine without the line below
669
        x = torch.rand(20, 2).cuda()
670
        p = Process(target=run, args=(rank,))
671
        p.start()
672
        processes.append(p)
673
    for p in processes:
674
        p.join()
675
"""
676
        )
677
        self.assertRegex(stderr, "Cannot re-initialize CUDA in forked subprocess.")
678

679
    @unittest.skipIf(
680
        NO_MULTIPROCESSING_SPAWN,
681
        "Disabled for environments that \
682
                     don't support multiprocessing with spawn start method",
683
    )
684
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
685
    def test_event(self):
        """Check that an interprocess CUDA event orders work across two processes.

        On a side stream the parent queues a ~GPU sleep followed by
        ``add_(1)``, records an interprocess event, and sends
        ``(event, tensor)`` to ``cuda_multiply_two``. The child waits on the
        event before doubling the tensor, so the result must be
        ``(1 + 1) * 2 == 4`` in every element.
        """
        ctx = mp.get_context("spawn")
        queue = ctx.Queue()
        ready = ctx.Event()
        done = ctx.Event()
        p = ctx.Process(target=cuda_multiply_two, args=(queue, ready, done))
        p.start()

        # Do not start GPU work until the child is running.
        ready.wait()
        with torch.cuda.stream(torch.cuda.Stream()):
            tensor = torch.cuda.FloatTensor([1, 1, 1, 1])
            # Use a sleep kernel to test events. Without the event, the
            # multiply happens before the add.
            event = torch.cuda.Event(interprocess=True)
            torch.cuda._sleep(20000000)  # about 30 ms
            tensor.add_(1)
            event.record()
            queue.put((event, tensor))
            done.wait()  # must wait until subprocess records event
            event.synchronize()
            self.assertEqual(list(tensor), [4, 4, 4, 4])
        p.join()
707

708
    @staticmethod
709
    def _test_event_multiprocess_child(event, p2c, c2p):
        """Child side of ``test_event_multiprocess``.

        Handshake: send 0 on ``c2p`` (ready), block on ``p2c`` until the parent
        has recorded ``event``, synchronize on it, then send 1 on ``c2p``.
        """
        c2p.put(0)  # notify parent child is ready
        p2c.get()  # wait for record in parent
        event.synchronize()
        c2p.put(1)  # notify parent synchronization is done
714

715
    @unittest.skipIf(
716
        NO_MULTIPROCESSING_SPAWN,
717
        "Disabled for environments that \
718
                     don't support multiprocessing with spawn start method",
719
    )
720
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
721
    def test_event_multiprocess(self):
        """An interprocess event passed to a spawned child can be synchronized there.

        The parent records the event after a ~50 ms GPU spin. The event must
        still be pending right after recording and complete once the child
        reports that its ``synchronize()`` returned.
        """
        event = torch.cuda.Event(enable_timing=False, interprocess=True)
        # A never-recorded event reports as complete.
        self.assertTrue(event.query())

        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=TestMultiprocessing._test_event_multiprocess_child,
            args=(event, p2c, c2p),
        )
        p.start()

        c2p.get()  # wait until the child process is ready
        torch.cuda._sleep(50000000)  # spin for about 50 ms
        event.record()
        p2c.put(0)  # notify child event is recorded

        self.assertFalse(event.query())
        c2p.get()  # wait for synchronization in child
        self.assertTrue(event.query())
        p.join()
743

744
    @unittest.skipIf(
745
        NO_MULTIPROCESSING_SPAWN,
746
        "Disabled for environments that \
747
                     don't support multiprocessing with spawn start method",
748
    )
749
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
750
    @unittest.skipIf(not TEST_MULTIGPU, "found only 1 GPU")
751
    def test_event_handle_multi_gpu(self):
        """Get IPC handles for cuda:0 events while cuda:1 is the current device.

        This is done for an un-recorded event and for a recorded one; the test
        passes if no call raises.
        """
        d0 = torch.device("cuda:0")
        d1 = torch.device("cuda:1")
        with torch.cuda.device(d0):
            e0 = torch.cuda.Event(enable_timing=False, interprocess=True)

        with torch.cuda.device(d1):
            # create handle on different device from un-recorded event
            e0.ipc_handle()

        with torch.cuda.device(d0):
            e1 = torch.cuda.Event(enable_timing=False, interprocess=True)
            stream = torch.cuda.Stream()
            torch.cuda._sleep(50000000)  # spin for about 50 ms
            e1.record(stream)

        with torch.cuda.device(d1):
            # create handle on different device from recorded event
            e1.ipc_handle()
770

771
    @staticmethod
772
    def _test_event_handle_importer_consumer(handle, p2c, c2p):
        """Child side of ``test_event_handle_importer``.

        Opens the parent's event from ``handle`` on device 0, then does a
        handshake. It signals ready, waits for the parent's record, and
        synchronizes on the imported event. It then reports completion and
        stays alive until the parent says it is done.
        """
        e1 = torch.cuda.Event.from_ipc_handle(0, handle)
        c2p.put(0)  # notify parent child is ready
        p2c.get()  # wait for record in parent
        e1.synchronize()
        c2p.put(1)  # notify synchronization is done in child
        p2c.get()  # wait for parent to finish before destructing child event
779

780
    @unittest.skipIf(
781
        NO_MULTIPROCESSING_SPAWN,
782
        "Disabled for environments that \
783
                     don't support multiprocessing with spawn start method",
784
    )
785
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
786
    def test_event_handle_importer(self):
        """A child that imports an event via its IPC handle can wait on the parent's record.

        The parent records ``e0`` after a ~50 ms GPU spin. ``e0`` must be
        pending until the child reports that its imported copy synchronized.
        """
        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
        self.assertTrue(e0.query())

        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=TestMultiprocessing._test_event_handle_importer_consumer,
            args=(e0.ipc_handle(), p2c, c2p),
        )
        p.start()

        c2p.get()  # wait for child to become ready
        torch.cuda._sleep(50000000)  # spin for about 50 ms
        e0.record()
        p2c.put(0)  # notify child event is recorded

        self.assertFalse(e0.query())
        c2p.get()  # wait for synchronization in child
        self.assertTrue(e0.query())
        p2c.put(1)  # notify child that parent is done
        p.join()
809

810
    @staticmethod
811
    def _test_event_handle_exporter_consumer(handle, p2c, c2p):
        """Child side of ``test_event_handle_exporter``.

        Opens the parent's event from ``handle`` on the current device. On a
        new stream it records the event after a ~50 ms GPU spin, then notifies
        the parent. It keeps the event alive until the parent has finished
        synchronizing.
        """
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            e1 = torch.cuda.Event.from_ipc_handle(torch.cuda.current_device(), handle)
            torch.cuda._sleep(50000000)  # spin for about 50 ms
            e1.record()
            c2p.put(0)
            # wait for parent process finished synchronization before
            # destructing e1
            p2c.get()
821

822
    @unittest.skipIf(
823
        NO_MULTIPROCESSING_SPAWN,
824
        "Disabled for environments that \
825
                     don't support multiprocessing with spawn start method",
826
    )
827
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
828
    def test_event_handle_exporter(self):
        """The parent can wait on an event that a child recorded via the exported handle.

        After the child reports its record, ``e0`` must be pending, and it must
        be complete once ``synchronize()`` returns in the parent.
        """
        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)

        ctx = mp.get_context("spawn")
        p2c = ctx.SimpleQueue()
        c2p = ctx.SimpleQueue()
        p = ctx.Process(
            target=TestMultiprocessing._test_event_handle_exporter_consumer,
            args=(e0.ipc_handle(), p2c, c2p),
        )
        p.start()
        # wait for event in child process is recorded
        c2p.get()

        self.assertFalse(e0.query())
        e0.synchronize()
        self.assertTrue(e0.query())
        # Let the child release its imported event and exit.
        p2c.put(0)
        p.join()
847

848
    def _test_empty_tensor_sharing(self, dtype, device):
        """Round-trip an empty tensor through an ``mp.Queue`` within this process."""
        original = torch.tensor([], dtype=dtype, device=device)
        channel = mp.Queue()
        channel.put(original)
        self.assertEqual(channel.get(timeout=1), original)
854

855
    def test_empty_tensor_sharing(self):
        """Empty float32 and int64 CPU tensors survive a queue round-trip."""
        cpu = torch.device("cpu")
        for dtype in (torch.float32, torch.int64):
            self._test_empty_tensor_sharing(dtype, cpu)
858

859
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
860
    def test_empty_tensor_sharing_cuda(self):
        """Empty float32 and int64 CUDA tensors survive a queue round-trip."""
        cuda = torch.device("cuda")
        for dtype in (torch.float32, torch.int64):
            self._test_empty_tensor_sharing(dtype, cuda)
863

864
    def test_empty_tensor_sharing_meta(self):
        """Empty float32 and int64 meta tensors survive a queue round-trip."""
        meta = torch.device("meta")
        for dtype in (torch.float32, torch.int64):
            self._test_empty_tensor_sharing(dtype, meta)
867

868
    def test_tensor_sharing_meta(self):
        """A one-element meta tensor survives a queue round-trip."""
        q = mp.Queue()
        sent = torch.tensor([1], dtype=torch.float32, device=torch.device("meta"))
        q.put(sent)
        self.assertEqual(q.get(timeout=1), sent)
876

877
    def test_meta_simple(self):
878
        self._test_sharing(mp.get_context("spawn"), "meta", torch.float)
879

880
    def _test_autograd_sharing(self, var, ctx=mp, is_parameter=False):
        """Send ``var`` (with a 5x5 grad attached) to an ``autograd_sharing`` worker.

        Handshake, in order:

        1. The parent puts ``var`` on the queue.
        2. After the worker sets ``ready``, the parent writes 1000 into
           ``var.data[0, 0]`` and fills the grad with 4s, then sets
           ``master_modified``.
        3. The parent reads a status value from the queue and asserts it is
           truthy.
        4. The parent asserts that ``var.data`` is all ones (presumably written
           back by the worker; confirm against ``autograd_sharing``) and that the
           grad still holds 4s.
        5. The worker must exit within 100 seconds.
        """
        device = "cuda" if var.is_cuda else "cpu"

        ready, master_modified = ctx.Event(), ctx.Event()
        queue = ctx.Queue()
        worker = ctx.Process(
            target=autograd_sharing,
            args=(queue, ready, master_modified, device, is_parameter),
        )
        worker.daemon = True
        worker.start()

        # A closure hook: sending ``var`` would error if its hooks were
        # serialized, because pickle does not support closures.
        @torch.utils.hooks.unserializable_hook
        def _noop_hook(*_ignored):
            pass

        if var.requires_grad:
            var.register_hook(_noop_hook)
        var._grad = torch.zeros(5, 5, device=device)
        queue.put(var)

        ready.wait()
        var.data[0, 0] = 1000
        var.grad.data[:] = torch.ones(5, 5, device=device) * 4
        master_modified.set()

        self.assertTrue(queue.get())

        ones = torch.ones(5, 5, device=device)
        self.assertEqual(var.data, ones)
        self.assertEqual(var.grad.data, ones * 4)
        worker.join(100)
        self.assertFalse(worker.is_alive())
916

917
    # Check sharing a cudaMalloc allocation with different types of storage.
    # (Issue #11422)
    def _test_mixed_types_cuda_sharing(self, ctx=mp):
        """Receive float/byte tensor pairs from ``mixed_type_producer`` for 10 rounds.

        Each round, the parent reads a float tensor (expected all ones) and a
        byte tensor (expected all zeros), drops its references, and sets
        ``event``. Setting the event presumably lets the producer move on to
        the next round.
        """
        expected_float = torch.ones(2, 2).float()
        expected_byte = torch.zeros(2, 2).byte()
        queue = ctx.Queue()
        event = ctx.Event()

        producer = ctx.Process(target=mixed_type_producer, args=(queue, event))
        producer.start()

        for _round in range(10):
            received_float = queue.get()
            received_byte = queue.get()
            self.assertEqual(received_float, expected_float)
            self.assertEqual(received_byte, expected_byte)
            # Release our handles on the shared tensors before signalling.
            del received_float, received_byte
            event.set()

        # NOTE(review): fixed grace period before joining; presumably lets the
        # producer finish releasing shared CUDA memory — confirm before removing.
        time.sleep(5)
        producer.join()
939

940
    @unittest.skipIf(
        TEST_WITH_ASAN,
        "non-deterministically hangs with ASAN https://github.com/pytorch/pytorch/issues/94024",
    )
    def test_variable_sharing(self):
        """Share a 5x5 CPU tensor, with and without requires_grad, via the default context."""
        for flag in (True, False):
            base = torch.arange(1.0, 26).view(5, 5)
            self._test_autograd_sharing(base.requires_grad_(flag))
948

949
    # See https://github.com/pytorch/pytorch/issues/14997
    @unittest.skipIf(TEST_WITH_ASAN, "non-deterministically hangs with ASAN")
    def test_leaf_variable_sharing(self):
        """Leaf tensors keep their ``requires_grad`` flag when sent to a worker.

        Runs on CPU (default context). It also runs on CUDA (``spawn`` context)
        when CUDA IPC and the ``spawn`` start method are both available. For
        each case, the worker (``requires_grad_variable_sharing``) gets the
        tensor, sets ``ready``, and is expected to send back the
        ``requires_grad`` value it observed.
        """
        devices = ["cpu"]
        # TEST_CUDA_IPC already includes torch.cuda.is_available().
        if TEST_CUDA_IPC and not NO_MULTIPROCESSING_SPAWN:
            devices.append("cuda")
        for device in devices:
            ctx = mp.get_context("spawn") if device == "cuda" else mp
            for requires_grad in [True, False]:
                var = (
                    torch.arange(1.0, 26, device=device)
                    .view(5, 5)
                    .requires_grad_(requires_grad)
                )
                self.assertTrue(var.is_leaf)
                ready = ctx.Event()
                queue = ctx.Queue()
                p = ctx.Process(
                    target=requires_grad_variable_sharing, args=(queue, ready)
                )
                p.daemon = True
                p.start()
                queue.put(var)
                # Wait until the worker has taken ``var`` off the queue, so the
                # following get() reads the worker's reply, not our own tensor.
                ready.wait()
                worker_requires_grad = queue.get()
                self.assertEqual(worker_requires_grad, requires_grad)
                # Reap the worker so processes do not pile up across iterations.
                p.join(100)
                self.assertFalse(p.is_alive())
975

976
    def test_non_leaf_variable_sharing(self):
977
        devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
978
        for device in devices:
979
            var0 = torch.arange(1.0, 26, device=device).view(5, 5).requires_grad_(True)
980
            var = var0 * 2
981
            # Don't use a regular Queue; it uses a background thread (which
982
            # means we can't catch the exceptions)
983
            queue = mp.SimpleQueue()
984
            self.assertRaisesRegex(
985
                RuntimeError, r"requires_grad", lambda: queue.put(var)
986
            )
987

988
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
                     don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_variable_sharing(self):
        """Share a 5x5 CUDA tensor, with and without requires_grad, under ``spawn``."""
        spawn_ctx = mp.get_context("spawn")
        for flag in (True, False):
            gpu_var = torch.arange(1.0, 26, device="cuda").view(5, 5)
            self._test_autograd_sharing(gpu_var.requires_grad_(flag), spawn_ctx)
1002

1003
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
                     don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_mixed_types_cuda_sharing(self):
        """Run the mixed-dtype CUDA storage sharing check under the ``spawn`` context."""
        spawn_ctx = mp.get_context("spawn")
        self._test_mixed_types_cuda_sharing(spawn_ctx)
1011

1012
    def test_parameter_sharing(self):
1013
        param = Parameter(torch.arange(1.0, 26).view(5, 5))
1014
        self._test_autograd_sharing(param, is_parameter=True)
1015

1016
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
                     don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_cuda_parameter_sharing(self):
        """Share a 5x5 CUDA ``Parameter`` through the ``spawn`` context."""
        gpu_param = Parameter(torch.arange(1.0, 26, device="cuda").view(5, 5))
        spawn_ctx = mp.get_context("spawn")
        self._test_autograd_sharing(gpu_param, spawn_ctx, is_parameter=True)
1025

1026
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
                     don't support multiprocessing with spawn start method",
    )
    def test_integer_parameter_serialization_cpu(self):
        """An integer ``Parameter`` on CPU can be passed to a spawned child."""
        target_device = "cpu"
        self._test_integer_parameter_serialization(device=target_device)
1033

1034
    @unittest.skipIf(
        NO_MULTIPROCESSING_SPAWN,
        "Disabled for environments that \
                     don't support multiprocessing with spawn start method",
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_integer_parameter_serialization_cuda(self):
        """An integer ``Parameter`` on CUDA can be passed to a spawned child."""
        target_device = "cuda"
        self._test_integer_parameter_serialization(device=target_device)
1042

1043
    def _test_integer_parameter_serialization(self, device):
        """Pass an int64 scalar ``Parameter`` (requires_grad=False) to a spawned child.

        The child runs ``integer_parameter_serialization``. A zero exit code is
        treated as successful serialization for ``device``.
        """
        scalar = torch.tensor(0, dtype=torch.int64, device=device)
        param = torch.nn.Parameter(scalar, requires_grad=False)

        child = mp.get_context("spawn").Process(
            target=integer_parameter_serialization, args=(param,)
        )
        child.start()
        child.join()

        failure_msg = f'Failed to serialize successfully for "{device}" device!'
        self.assertEqual(0, child.exitcode, msg=failure_msg)
1058

1059
    def test_empty_shared(self):
1060
        t = torch.tensor([])
1061
        t.share_memory_()
1062

1063
    def _test_is_shared(self):
1064
        t = torch.randn(5, 5)
1065
        self.assertFalse(t.is_shared())
1066
        t.share_memory_()
1067
        self.assertTrue(t.is_shared())
1068

1069
    @unittest.skipIf(
1070
        platform == "darwin", "file descriptor strategy is not supported on macOS"
1071
    )
1072
    def test_is_shared(self):
1073
        self._test_is_shared()
1074

1075
    def test_fs_is_shared(self):
        """Run the ``is_shared()`` check inside ``fs_sharing()`` (presumably the file_system strategy)."""
        with fs_sharing():
            return self._test_is_shared()
1078

1079
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
1080
    def test_is_shared_cuda(self):
1081
        t = torch.randn(5, 5).cuda()
1082
        self.assertTrue(t.is_shared())
1083

1084
    @unittest.skipIf(sys.platform != "linux", "Only runs on Linux; requires prctl(2)")
1085
    def test_set_thread_name(self):
1086
        name = "test name"
1087
        mp._set_thread_name(name)
1088
        self.assertEqual(mp._get_thread_name(), name)
1089

1090

1091
if __name__ == "__main__":
    # Executed as a script: hand control to PyTorch's shared test runner.
    run_tests()
1093

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.