# Owner(s): ["module: dataloader"]
2

3
import ctypes
4
import errno
5
import faulthandler
6
import functools
7
import gc
8
import itertools
9
import math
10
import operator
11
import os
12
import signal
13
import sys
14
import tempfile
15
import time
16
import unittest
17
import warnings
18

19
import torch
20
import torch.utils.data.datapipes as dp
21
from torch import multiprocessing as mp
22
from torch._utils import ExceptionWrapper
23
from torch.testing._internal.common_device_type import instantiate_device_type_tests
24
from torch.testing._internal.common_utils import (
25
    IS_CI,
26
    IS_JETSON,
27
    IS_MACOS,
28
    IS_SANDCASTLE,
29
    IS_WINDOWS,
30
    load_tests,
31
    NO_MULTIPROCESSING_SPAWN,
32
    parametrize,
33
    run_tests,
34
    skipIfNoDill,
35
    skipIfRocm,
36
    slowTest,
37
    TEST_CUDA,
38
    TEST_NUMPY,
39
    TEST_WITH_ASAN,
40
    TEST_WITH_ROCM,
41
    TEST_WITH_TSAN,
42
    TestCase,
43
)
44
from torch.utils.data import (
45
    _utils,
46
    ChainDataset,
47
    ConcatDataset,
48
    DataLoader,
49
    Dataset,
50
    IterableDataset,
51
    IterDataPipe,
52
    StackDataset,
53
    Subset,
54
    TensorDataset,
55
)
56
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
57
from torch.utils.data.datapipes.iter import IterableWrapper
58
from torch.utils.data.dataset import random_split
59

60

61
try:
62
    import psutil
63

64
    HAS_PSUTIL = True
65
except ModuleNotFoundError:
66
    HAS_PSUTIL = False
67
    psutil = None
68
    err_msg = (
69
        "psutil not found. Some critical data loader tests relying on it "
70
        "(e.g., TestDataLoader.test_proper_exit) will not run."
71
    )
72
    if IS_CI:
73
        raise ModuleNotFoundError(err_msg) from None
74
    else:
75
        warnings.warn(err_msg)
76

77

78
try:
79
    import numpy as np
80

81
    HAS_NUMPY = True
82
except ModuleNotFoundError:
83
    HAS_NUMPY = False
84
    np = None
85
skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")
86

87
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
88
# sharding on sandcastle. This line silences flake warnings
89
load_tests = load_tests
90

91
TEST_CUDA_IPC = (
92
    torch.cuda.is_available()
93
    and sys.platform != "darwin"
94
    and sys.platform != "win32"
95
    and not IS_JETSON
96
    and not TEST_WITH_ROCM
97
)  # https://github.com/pytorch/pytorch/issues/90940
98

99
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
100

101
if not NO_MULTIPROCESSING_SPAWN:
    # We want to use `spawn` if able because some of our tests check that the
    # data loader terminates gracefully. To prevent hanging in the testing
    # process, such data loaders are run in a separate subprocess.
    #
    # We also want to test the `pin_memory=True` configuration, thus `spawn` is
    # required to launch such processes, since they initialize the CUDA context.
    #
    # Mixing different start methods is a recipe for disaster (e.g., using a fork
    # `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally
    # to avoid bugs.
    #
    # Get a multiprocessing context, because some tests / third-party libraries
    # set the start method at import time, and setting it again raises a
    # `RuntimeError`.
    mp = mp.get_context(method="spawn")
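

# A minimal sketch (illustration only, not used by the test suite): when extra
# multiprocessing primitives are needed, they should come from the same explicit
# context so that, e.g., a fork-created Event is never handed to a spawn Process.
# The helper name `_example_spawn_context_primitives` is hypothetical.
def _example_spawn_context_primitives():
    ctx = torch.multiprocessing.get_context("spawn")
    event = ctx.Event()  # created by the same start method as any ctx.Process(...)
    queue = ctx.Queue()
    return event, queue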
116

117

118
# 60s of timeout?
# Yes, in environments where physical CPU resources are shared, e.g., CI, the
# time for an inter-process communication can vary widely.  With 15~17s of
# timeout, we have observed flakiness in some CI builds (see
# pytorch/pytorch#14501, pytorch/pytorch#16608).  We follow the CPython
# multiprocessing setup and set the timeout to 60s here:
#
# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
JOIN_TIMEOUT = 60.0  # seconds
127

128

129
supported_multiprocessing_contexts = [None] + list(
130
    torch.multiprocessing.get_all_start_methods()
131
)
132

133

134
# collate_fn that returns the batch cloned; defined globally here for pickle purposes.
def _clone_collate(b):
    return [x.clone() for x in b]
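

# A usage sketch added for illustration (not part of the upstream tests): because
# `_clone_collate` is a module-level function, it can be pickled and shipped to
# worker processes when `num_workers > 0`. The helper below is hypothetical and is
# never called by the test suite.
def _example_clone_collate_usage():
    data = [torch.randn(3) for _ in range(8)]  # a plain list of tensors is a valid map-style dataset
    loader = DataLoader(data, batch_size=2, collate_fn=_clone_collate)
    for batch in loader:
        # `_clone_collate` returns a list with each sample cloned instead of a stacked tensor
        assert isinstance(batch, list) and all(isinstance(x, torch.Tensor) for x in batch)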
137

138

139
@unittest.skipIf(
140
    TEST_WITH_TSAN,
141
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
142
    "fork is not supported. Dying (set die_after_fork=0 to override)",
143
)
144
class TestDatasetRandomSplit(TestCase):
145
    def test_lengths_must_equal_dataset_size(self):
146
        with self.assertRaises(ValueError):
147
            random_split([1, 2, 3, 4], [1, 2])
148

149
    def test_splits_have_correct_size(self):
150
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
151
        self.assertEqual(len(splits), 2)
152
        self.assertEqual(len(splits[0]), 2)
153
        self.assertEqual(len(splits[1]), 4)
154

155
        splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5])
156
        self.assertEqual(len(splits), 2)
157
        self.assertEqual(len(splits[0]), 3)
158
        self.assertEqual(len(splits[1]), 3)
159

160
        # Odd size splits
161
        self.assertEqual(
162
            len(
163
                random_split(
164
                    range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1)
165
                )
166
            ),
167
            2,
168
        )
169

170
        # Odd sized round-robin splits
171
        splits = random_split(
172
            range(106), [0.1, 0.2, 0.3, 0.4], generator=torch.Generator().manual_seed(1)
173
        )
174
        self.assertEqual(len(splits[0]), 11)
175
        self.assertEqual(len(splits[1]), 22)
176
        self.assertEqual(len(splits[2]), 31)
177
        self.assertEqual(len(splits[3]), 42)
178

179
    def test_splits_are_mutually_exclusive(self):
180
        data = [5, 2, 3, 4, 1, 6]
181
        splits = random_split(data, [2, 4])
182
        all_values = []
183
        all_values.extend(list(splits[0]))
184
        all_values.extend(list(splits[1]))
185
        data.sort()
186
        all_values.sort()
187
        self.assertListEqual(data, all_values)
188

189
        splits = random_split(data, [0.33, 0.67])
190
        all_values = []
191
        all_values.extend(list(splits[0]))
192
        all_values.extend(list(splits[1]))
193
        data.sort()
194
        all_values.sort()
195
        self.assertListEqual(data, all_values)
196

197
        data = [1, 2, 3, 4]
198
        splits = random_split(data, [0.25, 0.75])
199
        all_values = []
200
        all_values.extend(list(splits[0]))
201
        all_values.extend(list(splits[1]))
202
        data.sort()
203
        all_values.sort()
204
        self.assertListEqual(data, all_values)
205

206
    def test_splits_indexing_type(self):
207
        r"""Indices generated by random_split
208
        should be of integer type
209
        """
210

211
        class CustomDataset:
212
            def __init__(self, test_object, custom_list):
213
                self.data = custom_list
214
                self.test_object = test_object
215

216
            def __getitem__(self, key):
217
                self.test_object.assertEqual(type(key), int)
218
                return self.data[key]
219

220
            def __len__(self):
221
                return len(self.data)
222

223
        x = [1, 2, 3, 4, 5]
224
        dataset = CustomDataset(self, x)
225
        dataset = random_split(dataset, [5])[0]
226
        data_loader = DataLoader(dataset)
227
        for batch in data_loader:
228
            pass
229

230
        # fractional splitting
231
        dataset = CustomDataset(self, x)
232
        dataset = random_split(dataset, [1.0])[0]
233
        data_loader = DataLoader(dataset)
234
        for batch in data_loader:
235
            pass
236

237
    def test_splits_reproducibility(self):
238
        self.assertEqual(
239
            [
240
                list(x)
241
                for x in random_split(
242
                    range(10), [3, 7], generator=torch.Generator().manual_seed(1)
243
                )
244
            ],
245
            [[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]],
246
        )
247
        self.assertEqual(
248
            random_split(
249
                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
250
            ),
251
            random_split(
252
                range(100), [60, 40], generator=torch.Generator().manual_seed(42)
253
            ),
254
        )
255
        self.assertEqual(
256
            random_split(
257
                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
258
            ),
259
            random_split(
260
                range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
261
            ),
262
        )
263
        self.assertEqual(
264
            random_split(
265
                range(100),
266
                [0.33, 0.33, 0.34],
267
                generator=torch.Generator().manual_seed(42),
268
            ),
269
            random_split(
270
                range(100),
271
                [0.33, 0.33, 0.34],
272
                generator=torch.Generator().manual_seed(42),
273
            ),
274
        )
275

276
    def test_incomplete_fractional_splits(self):
277
        with self.assertRaises(ValueError):
278
            # should raise since the sum of fractions is not 1
279
            random_split([1, 2, 3, 4], [0.1])
280

281
        with self.assertRaises(ValueError):
282
            # should raise since fraction > 1
283
            random_split([1, 2, 3, 4], [1.1])
284

285
    def test_splits_generator(self):
286
        # A random_split without a specific generator should affect the default one
287
        state = torch.get_rng_state()
288
        a = torch.rand(10)
289
        torch.set_rng_state(state)
290
        random_split(range(10), [5, 5])
291
        b = torch.rand(10)
292
        self.assertNotEqual(a, b)
293

294
        # A random_split with a specific generator should not affect the default one
295
        state = torch.get_rng_state()
296
        a = torch.rand(10)
297
        torch.set_rng_state(state)
298
        random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42))
299
        b = torch.rand(10)
300
        self.assertEqual(a, b)
301

302
    def test_slicing_of_subset_of_dataset(self):
303
        # Testing slicing a subset initialized with a dataset
304
        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
305
        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
306
        self.assertEqual(subset_of_dataset[:], dataset[:])
307
        self.assertEqual(subset_of_dataset[1:2], dataset[1:2])
308
        self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2])
309
        # Testing slicing of subset from random split
310
        subset1, subset2 = random_split(dataset, [3, 2])
311
        self.assertEqual(subset1[:], dataset[subset1.indices[:]])
312
        self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]])
313
        self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]])
314

315
    def test_slicing_of_subset_of_subset(self):
316
        # Testing slicing a subset initialized with a subset
317
        dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
318
        subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
319
        subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4])
320
        self.assertEqual(subset_of_subset[:], dataset[:])
321
        self.assertEqual(subset_of_subset[0:2], dataset[0:2])
322
        self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2])
323
        # Testing slicing of subset of subset from random split
324
        subset1, subset2 = random_split(dataset, [4, 1])
325
        subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1])
326
        idx = [subset1.indices[i] for i in subset_of_subset1.indices]
327
        self.assertEqual(subset_of_subset1[:], dataset[idx.copy()])
328
        self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]])
329
        self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]])
330

331

332
class CUDACountingDataset(Dataset):
333
    def __init__(self, n):
334
        super().__init__()
335
        self.n = n
336

337
    def __getitem__(self, i):
338
        return torch.as_tensor(i, device="cuda")
339

340
    def __len__(self):
341
        return self.n
342

343

344
class CountingDataset(Dataset):
345
    def __init__(self, n):
346
        super().__init__()
347
        self.n = n
348

349
    def __getitem__(self, i):
350
        return i
351

352
    def __len__(self):
353
        return self.n
354

355

356
class CountingIterableDataset(IterableDataset):
357
    def __init__(self, n):
358
        super().__init__()
359
        self.n = n
360

361
    def __iter__(self):
362
        return iter(range(self.n))
363

364
    def __len__(self):
365
        return self.n
366

367

368
@unittest.skipIf(
369
    TEST_WITH_TSAN,
370
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
371
    "fork is not supported. Dying (set die_after_fork=0 to override)",
372
)
373
class TestTensorDataset(TestCase):
374
    def test_len(self):
375
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
376
        self.assertEqual(len(source), 15)
377

378
    def test_getitem(self):
379
        t = torch.randn(15, 10, 2, 3, 4, 5)
380
        l = torch.randn(15, 10)
381
        source = TensorDataset(t, l)
382
        for i in range(15):
383
            self.assertEqual(t[i], source[i][0])
384
            self.assertEqual(l[i], source[i][1])
385

386
    def test_getitem_1d(self):
387
        t = torch.randn(15)
388
        l = torch.randn(15)
389
        source = TensorDataset(t, l)
390
        for i in range(15):
391
            self.assertEqual(t[i], source[i][0])
392
            self.assertEqual(l[i], source[i][1])
393

394
    def test_single_tensor(self):
395
        t = torch.randn(5, 10)
396
        source = TensorDataset(t)
397
        self.assertEqual(len(source), 5)
398
        for i in range(5):
399
            self.assertEqual(t[i], source[i][0])
400

401
    def test_many_tensors(self):
402
        t0 = torch.randn(5, 10, 2, 3, 4, 5)
403
        t1 = torch.randn(5, 10)
404
        t2 = torch.randn(5, 10, 2, 5)
405
        t3 = torch.randn(5, 10, 3, 7)
406
        source = TensorDataset(t0, t1, t2, t3)
407
        self.assertEqual(len(source), 5)
408
        for i in range(5):
409
            self.assertEqual(t0[i], source[i][0])
410
            self.assertEqual(t1[i], source[i][1])
411
            self.assertEqual(t2[i], source[i][2])
412
            self.assertEqual(t3[i], source[i][3])
413

414

415
@unittest.skipIf(
416
    TEST_WITH_TSAN,
417
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
418
    "fork is not supported. Dying (set die_after_fork=0 to override)",
419
)
420
class TestStackDataset(TestCase):
421
    def test_empty(self):
422
        with self.assertRaisesRegex(
423
            ValueError, "At least one dataset should be passed"
424
        ):
425
            StackDataset()
426

427
    def test_mixed(self):
428
        with self.assertRaisesRegex(ValueError, "Supported either"):
429
            StackDataset(
430
                TensorDataset(torch.randn(15, 10)), a=TensorDataset(torch.randn(10, 15))
431
            )
432

433
    def test_size_mismatch(self):
434
        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
435
            StackDataset(
436
                TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(10, 15))
437
            )
438
        with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
439
            StackDataset(
440
                a=TensorDataset(torch.randn(15, 10)),
441
                b=TensorDataset(torch.randn(10, 15)),
442
            )
443

444
    def test_len(self):
445
        source = StackDataset(
446
            TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(15))
447
        )
448
        self.assertEqual(len(source), 15)
449
        source = StackDataset(TensorDataset(torch.randn(15, 10)))
450
        self.assertEqual(len(source), 15)
451
        source = StackDataset(
452
            a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(15))
453
        )
454
        self.assertEqual(len(source), 15)
455
        source = StackDataset(a=TensorDataset(torch.randn(15, 10)))
456
        self.assertEqual(len(source), 15)
457

458
    def test_single(self):
459
        t = TensorDataset(torch.randn(15, 10))
460
        source = StackDataset(t)
461
        for i in range(15):
462
            self.assertEqual(t[i], source[i][0])
463
        source = StackDataset(a=t)
464
        for i in range(15):
465
            self.assertEqual(t[i], source[i]["a"])
466

467
    def test_getitem(self):
468
        t = TensorDataset(torch.randn(15, 10))
469
        l = TensorDataset(torch.randn(15, 5, 4))
470
        source = StackDataset(t, l)
471
        for i in range(15):
472
            self.assertEqual(t[i], source[i][0])
473
            self.assertEqual(l[i], source[i][1])
474
        source = StackDataset(a=t, b=l)
475
        for i in range(15):
476
            self.assertEqual(t[i], source[i]["a"])
477
            self.assertEqual(l[i], source[i]["b"])
478

479
    def test_getitems(self):
480
        class GetItemsDataset(Dataset):
481
            def __init__(self) -> None:
482
                self.data = torch.randn(4)
483

484
            def __getitem__(self, item):
485
                return self.data[item]
486

487
            def __getitems__(self, items):
488
                return self.data[items]
489

490
            def __len__(self):
491
                return 4
492

493
        t = GetItemsDataset()
494
        l = [1, 2, 3, 4]
495

496
        source = StackDataset(t, l)
497
        batch = source.__getitems__([0, 1, 2, 3])
498
        for i in range(4):
499
            self.assertEqual(t[i], batch[i][0])
500
            self.assertEqual(l[i], batch[i][1])
501

502
        source = StackDataset(t=t, l=l)
503
        batch = source.__getitems__([0, 1, 2, 3])
504
        for i in range(4):
505
            self.assertEqual(t[i], batch[i]["t"])
506
            self.assertEqual(l[i], batch[i]["l"])
507

508
    def test_getitems_raises_index_error(self):
509
        class GetItemsDataset(Dataset):
510
            def __init__(self) -> None:
511
                self.data = torch.randn(4)
512

513
            def __getitem__(self, item):
514
                return self.data[item]
515

516
            def __getitems__(self, items):
517
                return self.data[items]
518

519
            def __len__(self):
520
                return 4
521

522
        t = GetItemsDataset()
523
        l = [1, 2, 3, 4]
524

525
        source = StackDataset(t, l)
526

527
        with self.assertRaises(IndexError):
528
            source.__getitems__([0, 4])
529

530
    def test_getitems_value_error(self):
531
        class GetItemsDataset(Dataset):
532
            def __init__(self) -> None:
533
                self.data = torch.randn(4)
534

535
            def __getitem__(self, item):
536
                return self.data[item]
537

538
            def __getitems__(self, items):
539
                return self.data[items][:-1]  # return less
540

541
            def __len__(self):
542
                return 4
543

544
        t = GetItemsDataset()
545
        l = [1, 2, 3, 4]
546

547
        source = StackDataset(t, l)
548

549
        with self.assertRaisesRegex(
550
            ValueError, "Nested dataset's output size mismatch. Expected 4, got 3"
551
        ):
552
            source.__getitems__([0, 1, 2, 3])
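

# Illustrative sketch (not from the upstream file): when a map-style dataset also
# defines `__getitems__`, the DataLoader fetcher passes it the whole list of batch
# indices instead of calling `__getitem__` once per index, as exercised by
# `TestStackDataset.test_getitems` above. `_ExampleBatchedDataset` is a
# hypothetical name and is not used by the tests.
class _ExampleBatchedDataset(Dataset):
    def __init__(self) -> None:
        self.data = torch.arange(8)

    def __getitem__(self, idx):
        return self.data[idx]

    def __getitems__(self, indices):
        # one vectorized lookup per batch; must return a list of samples
        return list(self.data[indices])

    def __len__(self):
        return len(self.data)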
553

554

555
@unittest.skipIf(
556
    TEST_WITH_TSAN,
557
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
558
    "fork is not supported. Dying (set die_after_fork=0 to override)",
559
)
560
class TestConcatDataset(TestCase):
561
    def test_concat_two_singletons(self):
562
        result = ConcatDataset([[0], [1]])
563
        self.assertEqual(2, len(result))
564
        self.assertEqual(0, result[0])
565
        self.assertEqual(1, result[1])
566

567
    def test_concat_two_non_singletons(self):
568
        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
569
        self.assertEqual(10, len(result))
570
        self.assertEqual(0, result[0])
571
        self.assertEqual(5, result[5])
572

573
    def test_concat_two_non_singletons_with_empty(self):
574
        # Adding an empty dataset somewhere is correctly handled
575
        result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
576
        self.assertEqual(10, len(result))
577
        self.assertEqual(0, result[0])
578
        self.assertEqual(5, result[5])
579

580
    def test_concat_raises_index_error(self):
581
        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
582
        with self.assertRaises(IndexError):
583
            # this one goes to 11
584
            result[11]
585

586
    def test_add_dataset(self):
587
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
588
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
589
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
590
        result = d1 + d2 + d3
591
        self.assertEqual(21, len(result))
592
        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
593
        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
594
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
595

596
    def test_iterable_dataset_err(self):
597
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
598
        it1 = CountingIterableDataset(5)
599
        it2 = CountingIterableDataset(10)
600

601
        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
602
            ConcatDataset([d1, it2, it1])
603

604
        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
605
            ConcatDataset([it2])
606

607
        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
608
            ConcatDataset([it1, d1])
609

610

611
# takes a dummy arg so this can also be used as a `worker_init_fn`
def set_faulthander_if_available(_=None):
    faulthandler.enable(sys.__stderr__)
    if not IS_WINDOWS:
        # Windows does not have faulthandler.register
        # chain=False prevents the default behavior of killing the process
        faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False)
618

619

620
set_faulthander_if_available()
621

622

623
# Process `pid` must have called `set_faulthander_if_available`
624
def print_traces_of_all_threads(pid):
625
    if not IS_WINDOWS:
626
        # use the custom signal if available
627
        os.kill(pid, signal.SIGUSR1)
628
    else:
629
        # otherwise we can still use the handler given by faulthandler.enable()
630
        # at the cost of killing the process.
631
        os.kill(pid, signal.SIGSEGV)
632

633
    # wait in parent process to give subprocess some time to print
634
    time.sleep(5)
635

636

637
# The following `ErrorTrackingProcess` stores the first encountered exception in
# its `.exception` attribute.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):
    # Why no *args?
    #   py2 doesn't support def fn(x, *args, key=val, **kwargs)
    # Setting disable_stderr=False may generate a lot of unrelated error outputs
    # but could be helpful for debugging.
    def __init__(self, disable_stderr=True, **kwargs):
646
        super().__init__(**kwargs)
647
        self._pconn, self._cconn = mp.Pipe()
648
        self._exception = None
649
        self.disable_stderr = disable_stderr
650

651
    def run(self):
652
        set_faulthander_if_available()
653
        if self.disable_stderr:
654
            # Disable polluting stderr with errors that are supposed to happen.
655
            with open(os.devnull, "w") as devnull:
656
                os.dup2(devnull.fileno(), sys.stderr.fileno())
657
        try:
658
            super().run()
659
            self._cconn.send(None)
660
        except Exception:
661
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
662
            raise
663

664
    def print_traces_of_all_threads(self):
665
        assert (
666
            self.is_alive()
667
        ), "can only use print_traces_of_all_threads if the process is alive"
668
        assert (
669
            not self.disable_stderr
670
        ), "do not disable stderr if you use print_traces_of_all_threads"
671
        # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
672
        # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
673
        # the process. So let's poll the exception first
674
        _ = self.exception
675
        print_traces_of_all_threads(self.pid)
676

677
    @property
678
    def exception(self):
679
        if self._pconn.poll():
680
            self._exception = self._pconn.recv()
681
        if self._exception is None:
682
            return None
683
        else:
684
            return self._exception.exc_type(self._exception.exc_msg)
685

686
    # ESRCH means that os.kill can't find a live proc with that pid
    def send_signal(self, signum, ignore_ESRCH=False):
688
        try:
689
            os.kill(self.pid, signum)
690
        except OSError as e:
691
            if not ignore_ESRCH or e.errno != errno.ESRCH:
692
                raise
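

# A minimal usage sketch (added for illustration, not part of the upstream file):
# the tests below start a crashing target in an `ErrorTrackingProcess` and then
# inspect the reconstructed exception from the parent. `_example_failing_target`
# and `_example_error_tracking_usage` are hypothetical helpers that the suite
# never calls.
def _example_failing_target():
    raise RuntimeError("boom")


def _example_error_tracking_usage():
    p = ErrorTrackingProcess(target=_example_failing_target)
    p.start()
    p.join(JOIN_TIMEOUT)
    try:
        assert not p.is_alive()
        assert isinstance(p.exception, RuntimeError)
    finally:
        p.terminate()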
693

694

695
class ErrorDataset(Dataset):
696
    def __init__(self, size):
697
        self.size = size
698

699
    def __len__(self):
700
        return self.size
701

702

703
class SegfaultDataset(Dataset):
704
    def __init__(self, size):
705
        self.size = size
706

707
    def __getitem__(self, idx):
708
        return ctypes.string_at(0)
709

710
    def __len__(self):
711
        return self.size
712

713

714
class SleepDataset(Dataset):
715
    def __init__(self, size, sleep_sec):
716
        self.size = size
717
        self.sleep_sec = sleep_sec
718
        self.sleeped = False
719

720
    def __getitem__(self, idx):
721
        if not self.sleeped:
722
            time.sleep(self.sleep_sec)
723
            self.sleeped = True
724
        return idx
725

726
    def __len__(self):
727
        return self.size
728

729

730
class SeedDataset(Dataset):
731
    def __init__(self, size):
732
        self.size = size
733

734
    def __getitem__(self, idx):
735
        return torch.initial_seed()
736

737
    def __len__(self):
738
        return self.size
739

740

741
class WorkerSpecificIterableDataset(IterableDataset):
742
    def __init__(self, sizes_for_all_workers):
743
        self.sizes_for_all_workers = sizes_for_all_workers
744

745
    def __iter__(self):
746
        worker_info = torch.utils.data.get_worker_info()
747
        assert worker_info is not None
748
        return iter(range(self.sizes_for_all_workers[worker_info.id]))
749

750
    def __len__(self):
751
        return sum(self.sizes_for_all_workers)
752

753

754
# Inspired by https://stackoverflow.com/a/26703365
# If all workers call `sync_once`, they will block until every worker reaches the
# call (i.e., it acts like a barrier).
# This can be used to ensure that each worker processes at least one batch.
class SynchronizedDataset(Dataset):
759
    def __init__(self, size, batch_size, num_workers):
760
        assert size >= num_workers * batch_size
761
        self.count = mp.Value("i", 0, lock=True)
762
        self.barrier = mp.Semaphore(0)
763
        self.num_workers = num_workers
764
        self.size = size
765

766
    def sync_once(self):
767
        with self.count.get_lock():
768
            self.count.value += 1
769
            if self.count.value == self.num_workers:
770
                self.barrier.release()
771
        self.barrier.acquire()
772
        self.barrier.release()
773

774
    def __getitem__(self, idx):
775
        raise NotImplementedError
776

777
    def __len__(self):
778
        return self.size
779

780

781
class EmptyTensorDataset(torch.utils.data.Dataset):
782
    def __init__(self, len):
783
        self.len = len
784

785
    def __len__(self):
786
        return self.len
787

788
    def __getitem__(self, any):
789
        return torch.empty(0)
790

791

792
class SynchronizedSeedDataset(SynchronizedDataset):
793
    def __getitem__(self, idx):
794
        self.sync_once()
795
        return torch.initial_seed()
796

797

798
def _test_timeout(persistent_workers):
799
    dataset = SleepDataset(10, 3)
800
    dataloader = DataLoader(
801
        dataset,
802
        batch_size=2,
803
        num_workers=2,
804
        timeout=1,
805
        persistent_workers=persistent_workers,
806
    )
807
    _ = next(iter(dataloader))
808

809

810
def _test_timeout_pin_memory(persistent_workers):
811
    dataset = SleepDataset(10, 3)
812
    dataloader = DataLoader(
813
        dataset,
814
        batch_size=2,
815
        num_workers=2,
816
        timeout=1,
817
        pin_memory=True,
818
        persistent_workers=persistent_workers,
819
    )
820
    _ = next(iter(dataloader))
821

822

823
def _test_large_sampler_indices(persistent_workers):
824
    # See
825
    #   test_large_sampler_indices
826
    #   https://github.com/pytorch/pytorch/issues/48666
827

828
    dataloader = torch.utils.data.DataLoader(
829
        EmptyTensorDataset(10000000),
830
        batch_size=40960,
831
        persistent_workers=persistent_workers,
832
        num_workers=1,
833
    )
834

835
    it = iter(dataloader)
836

837
    for x in it:
838
        assert x.numel() == 0
839
        raise RuntimeError("My Error")
840

841

842
def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since the worker signal handler prints with low-level write(),
    this has to be done at the OS level via dup.

    This is used as worker_init_fn for test_segfault.
    """
    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
    # The with-block is safe here: dup2 copies the devnull fd onto stderr's fd,
    # so closing the original devnull handle at the end of the block does not
    # restore stderr.
    with open(os.devnull, "w") as devnull:
        os.dup2(devnull.fileno(), sys.stderr.fileno())
855

856

857
def _test_segfault():
858
    dataset = SegfaultDataset(10)
859
    dataloader = DataLoader(
860
        dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr
861
    )
862
    _ = next(iter(dataloader))
863

864

865
def _test_no_segfault():
866
    dataset = [1, 2, 3]
867
    num_threads = torch.get_num_threads()
868
    if num_threads < 4:
869
        torch.set_num_threads(4)
870
    else:
871
        torch.set_num_threads(num_threads)
872
    mp_ctx = torch.multiprocessing.get_context(method="fork")
873
    dataloader = DataLoader(
874
        dataset,
875
        num_workers=1,
876
        worker_init_fn=disable_stderr,
877
        multiprocessing_context=mp_ctx,
878
    )
879
    _ = next(iter(dataloader))
880

881

882
class TestProperExitDataset(Dataset):
883
    def __init__(self, size, error_event):
884
        self.size = size
885
        self.error_event = error_event
886

887
    def __len__(self):
888
        return self.size
889

890
    def __getitem__(self, idx):
891
        worker_info = torch.utils.data.get_worker_info()
892
        if (
893
            self.error_event is not None
894
            and self.error_event.is_set()
895
            and worker_info.id == worker_info.num_workers - 1
896
        ):
897
            # only error in the last worker
898
            raise RuntimeError("Worker error")
899
        return torch.tensor([idx])
900

901

902
class TestProperExitIterableDataset(IterableDataset):
903
    def __init__(self, size, error_event):
904
        self.error_event = error_event
905
        self.size = size
906
        self.remaining = size
907

908
    def __len__(self):
909
        return self.size
910

911
    def __iter__(self):
912
        return self
913

914
    def __next__(self):
915
        worker_info = torch.utils.data.get_worker_info()
916
        if (
917
            self.error_event is not None
918
            and self.error_event.is_set()
919
            and worker_info.id == worker_info.num_workers - 1
920
        ):
921
            # only error in the last worker
922
            raise RuntimeError("Worker error")
923
        self.remaining -= 1
924
        if self.remaining < 0:
925
            raise StopIteration
926
        return torch.tensor(-1000)
927

928

929
# See TestDataLoader.test_proper_exit for usage
930
def _test_proper_exit(
931
    is_iterable_dataset,
932
    use_workers,
933
    pin_memory,
934
    exit_method,
935
    hold_iter_reference,
936
    loader_setup_event,
937
    tester_setup_event,
938
    persistent_workers,
939
):
940
    num_workers = 2 if use_workers else 0
941

942
    if exit_method == "worker_error" or exit_method == "worker_kill":
943
        assert use_workers is True
944

945
    if exit_method == "worker_error":
946
        worker_error_event = mp.Event()
947
    else:
948
        worker_error_event = None
949

950
    if is_iterable_dataset:
951
        ds = TestProperExitIterableDataset(7, worker_error_event)
952
    else:
953
        ds = TestProperExitDataset(12, worker_error_event)
954

955
    loader = DataLoader(
956
        ds,
957
        batch_size=1,
958
        shuffle=False,
959
        num_workers=num_workers,
960
        pin_memory=pin_memory,
961
        worker_init_fn=set_faulthander_if_available,
962
        persistent_workers=persistent_workers,
963
    )
964

965
    error_it = 2
966

967
    if use_workers:
968
        # 2 is the magical per-worker prefetch number...
969
        # FIXME: change this after the number becomes configurable.
970
        if is_iterable_dataset:
971
            assert len(ds) * num_workers > (error_it + 2 + 1)
972
        else:
973
            assert len(loader) > (error_it + 2 + 1) * num_workers
974
    else:
975
        if is_iterable_dataset:
976
            assert len(ds) > error_it + 1
977
        else:
978
            assert len(loader) > error_it + 1
979

980
    it = iter(loader)
981
    if use_workers:
982
        workers = it._workers
983

984
    def kill_pid(pid):
985
        psutil_p = psutil.Process(pid)
986
        psutil_p.kill()
987
        psutil_p.wait(JOIN_TIMEOUT)
988
        assert not psutil_p.is_running()
989

990
    for i, _ in enumerate(it):
991
        if i == 0:
992
            if not hold_iter_reference:
993
                del it
994
                del loader
995
            loader_setup_event.set()
996
            tester_setup_event.wait()
997
            # ensure that the workers are still alive
998
            if use_workers:
999
                for w in workers:
1000
                    assert w.is_alive()
1001
            if worker_error_event is not None:
1002
                worker_error_event.set()
1003

1004
        if i == error_it:
1005
            if exit_method == "loader_error":
1006
                raise RuntimeError("Loader error")
1007
            elif exit_method == "loader_kill":
1008
                kill_pid(os.getpid())
1009
            elif exit_method == "worker_kill":
1010
                kill_pid(workers[-1].pid)  # kill last worker
1011

1012
    if not hold_iter_reference:
1013
        # Tries to trigger the __del__ clean-up rather than the automatic
1014
        # exiting of daemonic children. Technically it should be automatically
1015
        # triggered, but I don't want to rely on the implementation detail of
1016
        # Python gc.
1017
        gc.collect()
1018

1019

1020
class TestWorkerInfoDataset(SynchronizedDataset):
1021
    def __getitem__(self, idx):
1022
        self.sync_once()
1023
        return torch.tensor(self.value)
1024

1025

1026
# Should be used as worker_init_fn with TestWorkerInfoDataset.
1027
# See _test_get_worker_info below for usage.
1028
def _test_worker_info_init_fn(worker_id):
1029
    worker_info = torch.utils.data.get_worker_info()
1030
    assert (
1031
        worker_id == worker_info.id
1032
    ), "worker_init_fn and worker_info should have consistent id"
1033
    assert (
1034
        worker_id < worker_info.num_workers
1035
    ), "worker_init_fn and worker_info should have valid id"
1036
    assert (
1037
        worker_info.seed == torch.initial_seed()
1038
    ), "worker_init_fn and worker_info should have consistent seed"
1039
    dataset = worker_info.dataset
1040
    assert isinstance(
1041
        dataset, TestWorkerInfoDataset
1042
    ), "worker_info should have correct dataset copy"
1043
    assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy"
1044
    # test that WorkerInfo attributes are read-only
1045
    try:
1046
        worker_info.id = 3999
1047
    except RuntimeError as e:
1048
        assert str(e) == "Cannot assign attributes to WorkerInfo objects"
1049
    try:
1050
        worker_info.a = 3
1051
    except RuntimeError as e:
1052
        assert str(e) == "Cannot assign attributes to WorkerInfo objects"
1053
    for k in ["id", "num_workers", "seed", "dataset"]:
1054
        assert f"{k}=" in repr(worker_info)
1055
    dataset.value = [worker_id, os.getpid()]
1056

1057

1058
def _test_get_worker_info():
1059
    # get_worker_info returns None in main proc
1060
    assert torch.utils.data.get_worker_info() is None
1061
    num_workers = 2
1062
    batch_size = 2
1063
    dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
1064
    dataloader = DataLoader(
1065
        dataset,
1066
        batch_size=batch_size,
1067
        num_workers=num_workers,
1068
        worker_init_fn=_test_worker_info_init_fn,
1069
    )
1070
    it = iter(dataloader)
1071
    data = []
1072
    for d in it:
1073
        data.append(d)  # noqa: PERF402
1074
    worker_pids = [w.pid for w in it._workers]
1075
    data = torch.cat(data, 0)
1076
    for d in data:
1077
        # each `d` is a [worker_id, worker_pid] pair, which is set in
1078
        # _test_worker_info_init_fn
1079
        assert d[1] == worker_pids[d[0]]
1080
    # get_worker_info returns None in main proc after data loading
1081
    assert torch.utils.data.get_worker_info() is None
1082
    # main proc dataset was never assigned this attribute
1083
    assert not hasattr(dataset, "value")
1084
    try:
1085
        _ = dataset[0]
1086
    except AttributeError:
1087
        return
1088
    raise RuntimeError("Expected AttributeError")
1089

1090

1091
# test custom init function
1092
def init_fn(worker_id):
1093
    torch.manual_seed(12345)
1094

1095

1096
# used with test_error_in_init
1097
class ErrorIterableDataset(IterableDataset):
1098
    def __iter__(self):
1099
        raise RuntimeError("Error in __iter__")
1100

1101

1102
# used with test_error_in_init
1103
def error_worker_init_fn(_):
1104
    raise RuntimeError("Error in worker_init_fn")
1105

1106

1107
class BulkLoadingDataset(Dataset):
1108
    def __init__(self, length):
1109
        self.length = length
1110

1111
    def __getitem__(self, indices):
1112
        assert isinstance(indices, (list, tuple))
1113
        return torch.as_tensor(indices)
1114

1115
    def __len__(self):
1116
        return self.length
1117

1118

1119
class BulkLoadingSampler(torch.utils.data.Sampler):
1120
    def __init__(self, dataset, batch_size):
1121
        self.dataset = dataset
1122
        self.batch_size = batch_size
1123

1124
    def __iter__(self):
1125
        for x in torch.randperm(len(self.dataset)).split(self.batch_size):
1126
            yield x.tolist()
1127

1128
    def __len__(self):
1129
        return int(math.ceil(len(self.dataset) / float(self.batch_size)))
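

# Illustrative sketch (not part of the upstream tests): with `batch_size=None`
# (auto-batching disabled) and a sampler that yields lists of indices, the
# DataLoader hands each whole list to `BulkLoadingDataset.__getitem__`, so the
# dataset fetches an entire batch in one call. `_example_bulk_loading` is a
# hypothetical helper that the suite never calls.
def _example_bulk_loading():
    ds = BulkLoadingDataset(16)
    sampler = BulkLoadingSampler(ds, batch_size=4)
    loader = DataLoader(ds, batch_size=None, sampler=sampler)
    for batch in loader:
        assert batch.shape == (4,)  # each element the sampler yields is a list of 4 indices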
1130

1131

1132
class TestMultiEpochDataset(IterableDataset):
1133
    def __init__(self, length):
1134
        self.length = length
1135

1136
    def __iter__(self):
1137
        worker_info = torch.utils.data.get_worker_info()
1138
        assert worker_info is not None
1139
        worker_id = worker_info.id
1140
        for idx in range(self.length // worker_info.num_workers):
1141
            yield worker_id
1142

1143
    def __len__(self):
1144
        return self.length
1145

1146

1147
class CustomList(list):
1148
    pass
1149

1150

1151
class CustomDict(dict):
1152
    pass
1153

1154

1155
def row_processor(row):
1156
    return np.add(row, 1)
1157

1158

1159
def filter_len(row):
1160
    return len(row) == 4
1161

1162

1163
@unittest.skipIf(
1164
    TEST_WITH_TSAN,
1165
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
1166
    "fork is not supported. Dying (set die_after_fork=0 to override)",
1167
)
1168
@unittest.skipIf(
1169
    TEST_WITH_ASAN,
1170
    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
1171
)
1172
class TestDataLoader(TestCase):
1173
    def setUp(self):
1174
        super().setUp()
1175
        self.data = torch.randn(100, 2, 3, 5)
1176
        self.labels = torch.randperm(50).repeat(2)
1177
        self.dataset = TensorDataset(self.data, self.labels)
1178
        self.persistent_workers = False
1179

1180
    def _get_data_loader(self, dataset, **kwargs):
1181
        persistent_workers = kwargs.get("persistent_workers", self.persistent_workers)
1182
        if persistent_workers and kwargs.get("num_workers", 0) == 0:
1183
            persistent_workers = False
1184
        kwargs["persistent_workers"] = persistent_workers
1185
        return DataLoader(dataset, **kwargs)
1186

1187
    def _test_sequential(self, loader):
1188
        batch_size = loader.batch_size
1189
        if batch_size is None:
1190
            for idx, (sample, target) in enumerate(loader):
1191
                self.assertEqual(sample, self.data[idx])
1192
                self.assertEqual(target, self.labels[idx])
1193
            self.assertEqual(idx, len(self.dataset) - 1)
1194
        else:
1195
            for i, (sample, target) in enumerate(loader):
1196
                idx = i * batch_size
1197
                self.assertEqual(sample, self.data[idx : idx + batch_size])
1198
                self.assertEqual(target, self.labels[idx : idx + batch_size])
1199
            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1200

1201
    def _test_shuffle(self, loader):
1202
        found_data = dict.fromkeys(range(self.data.size(0)), 0)
1203
        found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
1204
        batch_size = loader.batch_size
1205
        if batch_size is None:
1206
            for i, (batch_samples, batch_targets) in enumerate(loader):
1207
                sample, target = (batch_samples, batch_targets)
1208
                for data_point_idx, data_point in enumerate(self.data):
1209
                    if data_point.eq(sample).all():
1210
                        self.assertFalse(found_data[data_point_idx])
1211
                        found_data[data_point_idx] += 1
1212
                        break
1213
                self.assertEqual(target, self.labels[data_point_idx])
1214
                found_labels[data_point_idx] += 1
1215
                self.assertEqual(sum(found_data.values()), (i + 1))
1216
                self.assertEqual(sum(found_labels.values()), (i + 1))
1217
            self.assertEqual(i, (len(self.dataset) - 1))
1218
        else:
1219
            for i, (batch_samples, batch_targets) in enumerate(loader):
1220
                for sample, target in zip(batch_samples, batch_targets):
1221
                    for data_point_idx, data_point in enumerate(self.data):
1222
                        if data_point.eq(sample).all():
1223
                            self.assertFalse(found_data[data_point_idx])
1224
                            found_data[data_point_idx] += 1
1225
                            break
1226
                    self.assertEqual(target, self.labels[data_point_idx])
1227
                    found_labels[data_point_idx] += 1
1228
                self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
1229
                self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
1230
            self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1231

1232
    def _test_error(self, loader):
1233
        it = iter(loader)
1234
        errors = 0
1235
        while True:
1236
            try:
1237
                next(it)
1238
            except NotImplementedError:
1239
                errors += 1
1240
            except StopIteration:
1241
                self.assertEqual(
1242
                    errors, math.ceil(float(len(loader.dataset)) / loader.batch_size)
1243
                )
1244
                return
1245

1246
    def test_error_in_init(self):
1247
        for num_workers in [0, 2]:
1248
            loader = self._get_data_loader(
1249
                ErrorIterableDataset(), num_workers=num_workers
1250
            )
1251
            with self.assertRaisesRegex(RuntimeError, "Error in __iter__"):
1252
                list(iter(loader))
1253

1254
        loader = self._get_data_loader(
1255
            self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn
1256
        )
1257
        with self.assertRaisesRegex(RuntimeError, "Error in worker_init_fn"):
1258
            list(iter(loader))
1259

1260
    def test_typing(self):
1261
        from typing import List
1262

1263
        # Make sure there is no TypeError
1264

1265
        class SomeDatasetClass(Dataset[List[torch.Tensor]]):
1266
            pass
1267

1268
        def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
1269
            pass
1270

1271
    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
1272
    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
1273
    def test_fd_limit_exceeded(self):
1274
        # See NOTE [ DataLoader on Linux and open files limit ]
1275
        import subprocess
1276

1277
        subprocess.check_output(
1278
            [
1279
                sys.executable,
1280
                "-c",
1281
                """\
1282
import torch
1283
import resource
1284
from torch.utils.data import DataLoader, IterableDataset
1285

1286
class RandomDataset(IterableDataset):
1287
    def __init__(self, len, size):
1288
        super(RandomDataset).__init__()
1289
        self.len = len
1290
        self.size = size
1291

1292
    def __iter__(self):
1293
        return self
1294

1295
    def __next__(self):
1296
        if self.len <= 0:
1297
            raise StopIteration
1298
        self.len -= 1
1299
        return torch.randn(self.size)
1300

1301
try:
1302
    keep_fds_alive = []
1303
    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
1304
    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
1305
                               num_workers=1):
1306
      random_t.max(dim=0)
1307
      keep_fds_alive.append(random_t)
1308
except RuntimeError as e:
1309
    assert "ulimit -n" in str(e)
1310
    assert "set_sharing_strategy" in str(e)
1311
""",
1312
            ]
1313
        )
1314

1315
    def test_invalid_assign_after_init(self):
1316
        dl = self._get_data_loader(self.dataset)
1317
        for attr in ("batch_size", "sampler", "batch_sampler", "drop_last", "dataset"):
1318

1319
            def fn():
1320
                setattr(dl, attr, {})
1321

1322
            self.assertRaises(ValueError, fn)
1323

1324
    def test_sequential_nonbatch(self):
1325
        self._test_sequential(self._get_data_loader(self.dataset, batch_size=None))
1326

1327
    def test_sequential_batch(self):
1328
        self._test_sequential(self._get_data_loader(self.dataset))
1329
        self._test_sequential(self._get_data_loader(self.dataset, batch_size=2))
1330

1331
    def test_bulk_loading_nobatch(self):
1332
        n = 35
1333
        bs = 4
1334
        ds = BulkLoadingDataset(n)
1335
        sampler = BulkLoadingSampler(ds, batch_size=4)
1336

1337
        for num_workers in [0, 4]:
1338
            dl = self._get_data_loader(
1339
                ds,
1340
                num_workers=num_workers,
1341
                batch_size=None,
1342
                sampler=sampler,
1343
                pin_memory=TEST_CUDA,
1344
            )
1345
            self.assertFalse(dl._auto_collation)
1346
            samples = list(dl)
1347
            self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
1348
            self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))
1349

1350
    def test_growing_dataset(self):
1351
        dataset = [torch.ones(4) for _ in range(4)]
1352
        dataloader_seq = self._get_data_loader(dataset, shuffle=False)
1353
        dataloader_shuffle = self._get_data_loader(dataset, shuffle=True)
1354
        dataset.append(torch.ones(4))
1355
        self.assertEqual(len(dataloader_seq), 5)
1356
        self.assertEqual(len(dataloader_shuffle), 5)
1357

1358
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1359
    def test_sequential_pin_memory(self):
1360
        loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
1361
        for input, target in loader:
1362
            self.assertTrue(input.is_pinned())
1363
            self.assertTrue(target.is_pinned())
1364

1365
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1366
    def test_multiple_dataloaders(self):
1367
        for multiprocessing_context in supported_multiprocessing_contexts:
1368
            loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1))
1369
            loader2_it = iter(
1370
                self._get_data_loader(
1371
                    self.dataset,
1372
                    num_workers=2,
1373
                    multiprocessing_context=multiprocessing_context,
1374
                )
1375
            )
1376
            next(loader1_it)
1377
            next(loader1_it)
1378
            next(loader2_it)
1379
            next(loader2_it)
1380
            next(loader1_it)
1381
            next(loader2_it)
1382
            del loader1_it
1383
            del loader2_it
1384

1385
    def test_segfault(self):
1386
        p = ErrorTrackingProcess(target=_test_segfault)
1387
        p.start()
1388
        p.join(JOIN_TIMEOUT)
1389
        try:
1390
            self.assertFalse(p.is_alive())
1391
            self.assertNotEqual(p.exitcode, 0)
1392
            if IS_WINDOWS:
1393
                self.assertIsInstance(p.exception, OSError)
1394
                self.assertRegex(str(p.exception), r"access violation reading ")
1395
            else:
1396
                self.assertIsInstance(p.exception, RuntimeError)
1397
                self.assertRegex(
1398
                    str(p.exception),
1399
                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
1400
                )
1401
        finally:
1402
            p.terminate()
1403

1404
    # Tests that the child process forked by the DataLoader does not segfault due to having more than
    # 3 threads in the parent process after at least one set_num_threads invocation in the parent process.
    # After forking, set_num_threads(1) in the child process entails handling some inherited data structures
    # of the Caffe2 thread pool of the parent process, which used to culminate in a segfault.
    # Reference: https://github.com/pytorch/pytorch/issues/54752
1409
    @unittest.skipIf(IS_WINDOWS, "Needs fork")
1410
    def test_no_segfault(self):
1411
        p = ErrorTrackingProcess(target=_test_no_segfault)
1412
        p.start()
1413
        p.join(JOIN_TIMEOUT)
1414
        try:
1415
            self.assertFalse(p.is_alive())
1416
            if p.exception:
1417
                self.assertIsInstance(p.exception, RuntimeError)
1418
                self.assertRegex(
1419
                    str(p.exception),
1420
                    r"DataLoader worker \(pid \d+\) is killed by signal: ",
1421
                )
1422
                self.fail("Segfault occurred in worker process after fork")
1423
        finally:
1424
            p.terminate()
1425

1426
    def test_timeout(self):
1427
        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
1428
            # This test runs in a subprocess, which can only initialize CUDA with spawn.
1429
            # _test_timeout_pin_memory with pin_memory=True initializes CUDA when the iterator is
1430
            # constructed.
1431
            targets = (_test_timeout, _test_timeout_pin_memory)
1432
        else:
1433
            targets = (_test_timeout,)
1434
        for target in targets:
1435
            p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,))
1436
            p.start()
1437
            p.join(JOIN_TIMEOUT)
1438
            try:
1439
                self.assertFalse(p.is_alive())
1440
                self.assertNotEqual(p.exitcode, 0)
1441
                self.assertIsInstance(p.exception, RuntimeError)
1442
                self.assertRegex(
1443
                    str(p.exception), r"DataLoader timed out after \d+ seconds"
1444
                )
1445
            finally:
1446
                p.terminate()
1447

1448
    def test_large_sampler_indices(self):
        # Test that the data loader cleanly exits when the process errors while
        #   1. holding a reference to the iterator, and
        #   2. using a sampler that yields big elements s.t. _index_queues putters block
        #
        # More context: https://github.com/pytorch/pytorch/issues/48666
1454

1455
        p = ErrorTrackingProcess(
1456
            target=_test_large_sampler_indices, args=(self.persistent_workers,)
1457
        )
1458
        p.start()
1459
        p.join(JOIN_TIMEOUT)
1460
        try:
1461
            self.assertFalse(p.is_alive())
1462
            self.assertNotEqual(p.exitcode, 0)
1463
            self.assertIsInstance(p.exception, RuntimeError)
1464
            self.assertRegex(str(p.exception), r"My Error")
1465
        finally:
1466
            p.terminate()
1467

1468
    def test_invalid_ctor_args_combinations(self):
1469
        # general
1470
        with self.assertRaisesRegex(
1471
            ValueError, "num_workers option should be non-negative"
1472
        ):
1473
            self._get_data_loader(self.dataset, num_workers=-1)
1474
        with self.assertRaisesRegex(
1475
            ValueError, "timeout option should be non-negative"
1476
        ):
1477
            self._get_data_loader(self.dataset, timeout=-1)
1478

1479
        # disable auto-batching
1480
        with self.assertRaisesRegex(
1481
            ValueError,
1482
            "batch_size=None option disables auto-batching and is mutually exclusive",
1483
        ):
1484
            self._get_data_loader(self.dataset, batch_size=None, drop_last=True)
1485

1486
        valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
1487
        with self.assertRaisesRegex(
1488
            ValueError, r"multi-process loading \(num_workers > 0\), but got"
1489
        ):
1490
            self._get_data_loader(
1491
                self.dataset, num_workers=0, multiprocessing_context=valid_ctx
1492
            )
1493
        with self.assertRaisesRegex(
1494
            ValueError, "should specify a valid start method in"
1495
        ):
1496
            self._get_data_loader(
1497
                self.dataset, num_workers=1, multiprocessing_context="bad"
1498
            )
1499
        with self.assertRaisesRegex(
1500
            TypeError, "multiprocessing_context option should be a valid context "
1501
        ):
1502
            self._get_data_loader(
1503
                self.dataset, num_workers=1, multiprocessing_context=object()
1504
            )
1505

1506
        # map-style
1507
        sampler = torch.utils.data.SequentialSampler(self.dataset)
1508
        batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
1509
        with self.assertRaisesRegex(
1510
            ValueError, "sampler option is mutually exclusive with shuffle"
1511
        ):
1512
            self._get_data_loader(
1513
                self.dataset, batch_size=11, sampler=sampler, shuffle=True
1514
            )
1515
        with self.assertRaisesRegex(
1516
            ValueError, "sampler option is mutually exclusive with shuffle"
1517
        ):
1518
            self._get_data_loader(
1519
                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True
1520
            )
1521
        with self.assertRaisesRegex(
1522
            ValueError, "sampler option is mutually exclusive with shuffle"
1523
        ):
1524
            self._get_data_loader(
1525
                self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3
1526
            )
1527
        with self.assertRaisesRegex(
1528
            ValueError, "batch_sampler option is mutually exclusive with"
1529
        ):
1530
            self._get_data_loader(
1531
                self.dataset, batch_size=11, batch_sampler=batch_sampler
1532
            )
1533
        with self.assertRaisesRegex(
1534
            ValueError, "batch_sampler option is mutually exclusive with"
1535
        ):
1536
            self._get_data_loader(
1537
                self.dataset, shuffle=True, batch_sampler=batch_sampler
1538
            )
1539
        with self.assertRaisesRegex(
1540
            ValueError, "batch_sampler option is mutually exclusive with"
1541
        ):
1542
            self._get_data_loader(
1543
                self.dataset, drop_last=True, batch_sampler=batch_sampler
1544
            )
1545
        with self.assertRaisesRegex(
1546
            ValueError, "batch_sampler option is mutually exclusive with"
1547
        ):
1548
            self._get_data_loader(
1549
                self.dataset, drop_last=3, batch_sampler=batch_sampler
1550
            )
1551

1552
        # iterable-style
1553
        dataset = CountingIterableDataset(20)
1554
        with self.assertRaisesRegex(
1555
            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1556
        ):
1557
            self._get_data_loader(dataset, shuffle=True)
1558
        with self.assertRaisesRegex(
1559
            ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1560
        ):
1561
            self._get_data_loader(dataset, shuffle=3)
1562
        with self.assertRaisesRegex(
1563
            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1564
        ):
1565
            self._get_data_loader(
1566
                dataset, sampler=torch.utils.data.SequentialSampler(dataset)
1567
            )
1568
        with self.assertRaisesRegex(
1569
            ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1570
        ):
1571
            self._get_data_loader(dataset, sampler=3)
1572
        with self.assertRaisesRegex(
1573
            ValueError,
1574
            "DataLoader with IterableDataset: expected unspecified batch_sampler",
1575
        ):
1576
            self._get_data_loader(
1577
                dataset,
1578
                batch_sampler=torch.utils.data.BatchSampler(
1579
                    torch.utils.data.SequentialSampler(dataset), 3, False
1580
                ),
1581
            )
1582
        with self.assertRaisesRegex(
1583
            ValueError,
1584
            "DataLoader with IterableDataset: expected unspecified batch_sampler",
1585
        ):
1586
            self._get_data_loader(dataset, batch_sampler=3)
1587

1588
    def test_builtin_collection_conversion(self):
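        # Wrapping a DataLoader in list()/tuple() should eagerly drain it: with
        # batch_size=None the items stay plain Python ints, while auto-batching
        # (batch_size=2) yields tensors of shape [2].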
1589
        for coll_ty in (list, tuple):
1590
            for num_workers in (0, 1):
1591
                # map-style dataset
1592
                dataset = CountingDataset(20)
1593
                # no auto-batching
1594
                fetched = coll_ty(
1595
                    self._get_data_loader(
1596
                        dataset, batch_size=None, num_workers=num_workers
1597
                    )
1598
                )
1599
                self.assertEqual(fetched, coll_ty(range(20)))
1600
                # auto-batching
1601
                fetched = coll_ty(
1602
                    self._get_data_loader(
1603
                        dataset, batch_size=2, num_workers=num_workers
1604
                    )
1605
                )
1606
                self.assertEqual(
1607
                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1608
                )
1609

1610
                # iterable-style dataset
1611
                dataset = CountingIterableDataset(20)
1612
                # no auto-batching
1613
                fetched = coll_ty(
1614
                    self._get_data_loader(
1615
                        dataset, batch_size=None, num_workers=num_workers
1616
                    )
1617
                )
1618
                self.assertEqual(fetched, coll_ty(range(20)))
1619
                # auto-batching
1620
                # this IterableDataset isn't configured for each worker, so for
1621
                # the equality test below to be valid, we cannot use more than 1 worker.
1622
                assert num_workers in [0, 1], "invalid test"
1623
                fetched = coll_ty(
1624
                    self._get_data_loader(
1625
                        dataset, batch_size=2, num_workers=num_workers
1626
                    )
1627
                )
1628
                self.assertEqual(
1629
                    fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1630
                )
1631

1632
    def test_iterable_style_dataset(self):
1633
        # [no auto-batching] single process loading
1634
        dataset = CountingIterableDataset(20)
1635
        dataloader = self._get_data_loader(dataset, batch_size=None)
1636
        fetched = list(dataloader)
1637
        self.assertEqual(len(fetched), 20)
1638
        for i, d in enumerate(fetched):
1639
            # non-batched should not convert ints into tensors
1640
            self.assertIsInstance(d, int)
1641
            self.assertEqual(d, i)
1642
        # DataLoader should match len of the iterable-style dataset (if implemented)
1643
        self.assertEqual(len(dataloader), len(dataset))
1644

1645
        # [no auto-batching] multiprocessing loading
1646
        num_workers = 3
1647
        sizes_for_all_workers = [0, 4, 20]
1648
        expected = sorted(
1649
            functools.reduce(
1650
                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1651
            )
1652
        )
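        # The 3 workers will each iterate a worker-specific dataset of size 0, 4,
        # and 20 respectively, so the combined (order-independent) output equals
        # the sorted concatenation of range(0), range(4), and range(20).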
1653
        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1654
        for prefetch_factor in [2, 3, 4]:
1655
            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1656
            dataloader = self._get_data_loader(
1657
                dataset,
1658
                num_workers=num_workers,
1659
                batch_size=None,
1660
                worker_init_fn=set_faulthander_if_available,
1661
                prefetch_factor=prefetch_factor,
1662
            )
1663
            dataloader_iter = iter(dataloader)
1664
            fetched = sorted(dataloader_iter)
1665
            for a, b in zip(fetched, expected):
1666
                # non-batched should not convert ints into tensors
1667
                self.assertIsInstance(a, int)
1668
                self.assertEqual(a, b)
1669
            # DataLoader should match len of the iterable-style dataset (if implemented)
1670
            self.assertEqual(len(dataloader), len(dataset))
1671
            # When loading more than len(dataset) data, after accessing len(dataloader),
1672
            # we should get a warning. See NOTE [ IterableDataset and __len__ ].
1673
            dataset = CountingIterableDataset(20)
1674
            dataloader = self._get_data_loader(
1675
                dataset,
1676
                num_workers=num_workers,
1677
                worker_init_fn=set_faulthander_if_available,
1678
                prefetch_factor=prefetch_factor,
1679
            )
1680
            it = iter(dataloader)
1681
            for _ in range(40):
1682
                self.assertNotWarn(
1683
                    lambda: next(it), "Should not warn before accessing len(dataloader)"
1684
                )
1685
            self.assertEqual(len(dataloader), len(dataset))
1686
            self.assertEqual(len(dataloader), 20)
1687
            it = iter(dataloader)
1688
            for _ in range(20):
1689
                self.assertNotWarn(
1690
                    lambda: next(it), "Should not warn before exceeding length"
1691
                )
1692
            for _ in range(3):
1693
                with self.assertWarnsRegex(
1694
                    UserWarning,
1695
                    r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
1696
                    msg="Should always warn after exceeding length",
1697
                ):
1698
                    next(it)
1699
        # [no auto-batching] test that workers exit gracefully
1700
        workers = dataloader_iter._workers
1701
        del dataloader_iter
1702
        del dataloader
1703
        try:
1704
            for w in workers:
1705
                w.join(JOIN_TIMEOUT)
1706
                self.assertFalse(w.is_alive())
1707
                self.assertEqual(w.exitcode, 0)
1708
        finally:
1709
            for w in workers:
1710
                w.terminate()
1711

1712
        # [auto-batching] single process loading
1713
        dataset = CountingIterableDataset(20)
1714
        fetched = list(self._get_data_loader(dataset, batch_size=7))
1715
        self.assertEqual(len(fetched), 3)
1716
        self.assertEqual(fetched[0].tolist(), list(range(7)))
1717
        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1718
        self.assertEqual(fetched[2].tolist(), list(range(14, 20)))
1719

1720
        # [auto-batching] multiprocessing loading
1721
        num_workers = 3
1722
        sizes_for_all_workers = [0, 4, 20]
1723
        expected = sorted(
1724
            functools.reduce(
1725
                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1726
            )
1727
        )
1728
        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1729
        for prefetch_factor in [2, 3, 4]:
1730
            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1731
            # worker 0 should return 0 batches
1732
            # worker 1 should return 1 batch (its 4 samples)
1733
            # worker 2 should return 3 batches
1734
            dataloader = self._get_data_loader(
1735
                dataset,
1736
                num_workers=num_workers,
1737
                batch_size=7,
1738
                prefetch_factor=prefetch_factor,
1739
            )
1740
            dataloader_iter = iter(dataloader)
1741
            fetched = list(dataloader_iter)
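            # 4 batches in total: worker 1 contributes a single partial batch of
            # its 4 samples, and worker 2 contributes batches of 7, 7, and 6.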
1742
            self.assertEqual(len(fetched), 4)
1743
            fetched = {tuple(t.tolist()) for t in fetched}
1744
            self.assertEqual(
1745
                fetched,
1746
                {
1747
                    tuple(range(4)),
1748
                    tuple(range(7)),
1749
                    tuple(range(7, 14)),
1750
                    tuple(range(14, 20)),
1751
                },
1752
            )
1753

1754
            # [auto-batching] test that workers exit gracefully
1755
            workers = dataloader_iter._workers
1756
            del dataloader_iter
1757
            del dataloader
1758
            try:
1759
                for w in workers:
1760
                    w.join(JOIN_TIMEOUT)
1761
                    self.assertFalse(w.is_alive())
1762
                    self.assertEqual(w.exitcode, 0)
1763
            finally:
1764
                for w in workers:
1765
                    w.terminate()
1766
        # [auto-batching & drop_last] single process loading
1767
        dataset = CountingIterableDataset(20)
1768
        fetched = list(self._get_data_loader(dataset, batch_size=7, drop_last=True))
1769
        self.assertEqual(len(fetched), 2)
1770
        self.assertEqual(fetched[0].tolist(), list(range(7)))
1771
        self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1772

1773
        # [auto-batching & drop_last] multiprocessing loading
1774
        num_workers = 3
1775
        sizes_for_all_workers = [0, 4, 20]
1776
        expected = sorted(
1777
            functools.reduce(
1778
                operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1779
            )
1780
        )
1781
        assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1782
        for prefetch_factor in [2, 3, 4]:
1783
            dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1784
            # with drop_last=True, worker 0 should return 0 batches,
1785
            # worker 1 should return 0 batches (its partial batch of 4 is dropped),
1786
            # and worker 2 should return 2 batches (its trailing partial batch of 6 is dropped)
1787
            dataloader = self._get_data_loader(
1788
                dataset,
1789
                num_workers=num_workers,
1790
                batch_size=7,
1791
                drop_last=True,
1792
                worker_init_fn=set_faulthander_if_available,
1793
                prefetch_factor=prefetch_factor,
1794
            )
1795
            dataloader_iter = iter(dataloader)
1796
            fetched = list(dataloader_iter)
1797
            self.assertEqual(len(fetched), 2)
1798
            fetched = {tuple(t.tolist()) for t in fetched}
1799
            self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
1800

1801
            # [auto-batching & drop_last] test that workers exit gracefully
1802
            workers = dataloader_iter._workers
1803
            del dataloader_iter
1804
            del dataloader
1805
            try:
1806
                for w in workers:
1807
                    w.join(JOIN_TIMEOUT)
1808
                    self.assertFalse(w.is_alive())
1809
                    self.assertEqual(w.exitcode, 0)
1810
            finally:
1811
                for w in workers:
1812
                    w.terminate()
1813

1814
    def test_chain_iterable_style_dataset(self):
1815
        # chaining (concatenation)
1816
        dataset1 = CountingIterableDataset(20)
1817
        dataset2 = CountingIterableDataset(15)
1818
        expected = list(range(20)) + list(range(15))
1819
        for num_workers in [0, 1]:
1820
            for chained_dataset in [
1821
                dataset1 + dataset2,
1822
                ChainDataset([dataset1, dataset2]),
1823
            ]:
1824
                fetched = list(
1825
                    self._get_data_loader(chained_dataset, num_workers=num_workers)
1826
                )
1827
                self.assertEqual(len(fetched), len(expected))
1828
                for e, d in zip(expected, fetched):
1829
                    self.assertIsInstance(d, torch.Tensor)
1830
                    self.assertEqual(e, d)
1831

1832
        with self.assertRaisesRegex(
1833
            AssertionError, "ChainDataset only supports IterableDataset"
1834
        ):
1835
            list(iter(dataset1 + self.dataset))
1836

1837
        with self.assertRaisesRegex(
1838
            AssertionError, "ChainDataset only supports IterableDataset"
1839
        ):
1840
            list(iter(ChainDataset([dataset1, self.dataset])))
1841

1842
    @unittest.skipIf(IS_MACOS, "Not working on macos")
1843
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1844
    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/90940
1845
    def test_multiprocessing_contexts(self):
1846
        reference = [
1847
            torch.arange(3),
1848
            torch.arange(3, 6),
1849
            torch.arange(6, 9),
1850
            torch.arange(9, 11),
1851
        ]
1852
        counting_ds_n = 11
1853
        dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA))
1854
        for ctx in supported_multiprocessing_contexts:
1855
            # Windows and Jetson devices don't support sharing CUDA tensors; ROCm does not yet fully support IPC
1856
            if (
1857
                ctx in ["spawn", "forkserver"]
1858
                and TEST_CUDA
1859
                and not IS_WINDOWS
1860
                and not IS_JETSON
1861
            ):
1862
                ds_cls = CUDACountingDataset
1863
            else:
1864
                ds_cls = CountingDataset
1865
            self.assertEqual(
1866
                reference,
1867
                list(
1868
                    self._get_data_loader(
1869
                        ds_cls(counting_ds_n),
1870
                        multiprocessing_context=ctx,
1871
                        **dl_common_args,
1872
                    )
1873
                ),
1874
            )
1875
            if ctx is not None:
1876
                # test ctx object
1877
                ctx = mp.get_context(ctx)
1878
                self.assertEqual(
1879
                    reference,
1880
                    list(
1881
                        self._get_data_loader(
1882
                            ds_cls(counting_ds_n),
1883
                            multiprocessing_context=ctx,
1884
                            **dl_common_args,
1885
                        )
1886
                    ),
1887
                )
1888

1889
    def _test_multiprocessing_iterdatapipe(self, with_dill):
1890
        # Test that a function from global scope (e.g., imported from a library) can be serialized
1891
        # and used with a multiprocess DataLoader
1892

1893
        reference = [
1894
            torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1895
            torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1896
        ]
1897
        datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]])
1898
        datapipe = datapipe.map(row_processor)
1899
        datapipe = (
1900
            datapipe.filter(lambda row: len(row) == 4)
1901
            if with_dill
1902
            else datapipe.filter(filter_len)
1903
        )
1904

1905
        dl_common_args = dict(
1906
            num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA)
1907
        )
1908
        for ctx in supported_multiprocessing_contexts:
1909
            self.assertEqual(
1910
                reference,
1911
                [
1912
                    t.type(torch.int64)
1913
                    for t in self._get_data_loader(
1914
                        datapipe, multiprocessing_context=ctx, **dl_common_args
1915
                    )
1916
                ],
1917
            )
1918
            if ctx is not None:
1919
                # test ctx object
1920
                ctx = mp.get_context(ctx)
1921
                self.assertEqual(
1922
                    reference,
1923
                    [
1924
                        t.type(torch.int64)
1925
                        for t in self._get_data_loader(
1926
                            datapipe, multiprocessing_context=ctx, **dl_common_args
1927
                        )
1928
                    ],
1929
                )
1930

1931
    @skipIfNoNumpy
1932
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1933
    def test_multiprocessing_iterdatapipe(self):
1934
        self._test_multiprocessing_iterdatapipe(with_dill=False)
1935

1936
    @unittest.expectedFailure
1937
    @skipIfNoNumpy
1938
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1939
    @skipIfNoDill
1940
    def test_multiprocessing_iterdatapipe_with_dill(self):
1941
        self._test_multiprocessing_iterdatapipe(with_dill=True)
1942

1943
    def test_worker_seed(self):
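        # SynchronizedSeedDataset makes every worker report the seed it was
        # initialized with (inside a worker process, `torch.initial_seed()`
        # exposes this per-worker seed). The seeds are expected to be distinct,
        # so the set below should end up with exactly one entry per worker.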
1944
        num_workers = 6
1945
        batch_size = 1
1946
        dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1947
        dataloader = self._get_data_loader(
1948
            dataset, batch_size=batch_size, num_workers=num_workers
1949
        )
1950
        seeds = set()
1951
        seeds.update(batch[0] for batch in dataloader)
1952
        self.assertEqual(len(seeds), num_workers)
1953

1954
    def test_worker_seed_reproducibility(self):
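        # With a fixed generator (manual_seed(42)) the base seed handed to the
        # workers is deterministic, so two independently constructed loaders
        # should observe the same set of per-worker seeds.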
1955
        def get_dataloader():
1956
            return DataLoader(
1957
                dataset,
1958
                batch_size=batch_size,
1959
                num_workers=num_workers,
1960
                generator=torch.Generator().manual_seed(42),
1961
            )
1962

1963
        num_workers = 6
1964
        batch_size = 1
1965
        dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1966
        self.assertEqual(
1967
            {int(batch) for batch in get_dataloader()},
1968
            {int(batch) for batch in get_dataloader()},
1969
        )
1970

1971
    def test_multi_epochs_reproducibility(self):
1972
        num_workers = 2
1973
        batch_size = 10
1974
        num_epochs = 3
1975

1976
        dataset = TestMultiEpochDataset(batch_size * num_workers)
1977
        dataloader = self._get_data_loader(
1978
            dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
1979
        )
1980

1981
        for ind in range(num_epochs):
1982
            for batch_idx, sample in enumerate(dataloader):
1983
                self.assertEqual(
1984
                    sample.tolist(), [batch_idx % num_workers] * batch_size
1985
                )
1986

1987
    def test_worker_init_fn(self):
1988
        dataset = SeedDataset(4)
1989
        dataloader = self._get_data_loader(
1990
            dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn
1991
        )
1992
        for batch in dataloader:
1993
            self.assertEqual(12345, batch[0])
1994
            self.assertEqual(12345, batch[1])
1995

1996
    def test_get_worker_info(self):
1997
        p = ErrorTrackingProcess(target=_test_get_worker_info)
1998
        p.start()
1999
        p.join(JOIN_TIMEOUT)
2000
        try:
2001
            self.assertFalse(p.is_alive())
2002
            self.assertEqual(p.exitcode, 0)
2003
        finally:
2004
            p.terminate()
2005

2006
    def test_shuffle(self):
2007
        self._test_shuffle(self._get_data_loader(self.dataset, shuffle=True))
2008

2009
    def test_shuffle_batch_none(self):
2010
        self._test_shuffle(DataLoader(self.dataset, batch_size=None, shuffle=True))
2011

2012
    def test_shuffle_batch(self):
2013
        self._test_shuffle(
2014
            self._get_data_loader(self.dataset, batch_size=2, shuffle=True)
2015
        )
2016

2017
    def test_shuffle_reproducibility(self):
2018
        for fn in (
2019
            lambda: DataLoader(
2020
                self.dataset,
2021
                shuffle=True,
2022
                num_workers=0,
2023
                generator=torch.Generator().manual_seed(42),
2024
            ),
2025
            lambda: DataLoader(
2026
                self.dataset,
2027
                shuffle=True,
2028
                num_workers=2,
2029
                generator=torch.Generator().manual_seed(42),
2030
            ),
2031
        ):
2032
            self.assertEqual(list(fn()), list(fn()))
2033

2034
    def test_sequential_workers(self):
2035
        self._test_sequential(self._get_data_loader(self.dataset, num_workers=4))
2036

2037
    def test_seqential_batch_workers(self):
2038
        self._test_sequential(
2039
            self._get_data_loader(self.dataset, batch_size=2, num_workers=4)
2040
        )
2041

2042
    def test_seqential_batch_workers_prefetch(self):
2043
        self._test_sequential(
2044
            DataLoader(self.dataset, batch_size=2, num_workers=4, prefetch_factor=3)
2045
        )
2046

2047
    def test_shuffle_workers(self):
2048
        self._test_shuffle(
2049
            self._get_data_loader(self.dataset, shuffle=True, num_workers=4)
2050
        )
2051

2052
    def test_shuffle_batch_workers(self):
2053
        self._test_shuffle(
2054
            self._get_data_loader(
2055
                self.dataset, batch_size=2, shuffle=True, num_workers=4
2056
            )
2057
        )
2058

2059
    def test_shuffle_batch_workers_prefetch(self):
2060
        self._test_shuffle(
2061
            DataLoader(
2062
                self.dataset,
2063
                batch_size=2,
2064
                shuffle=True,
2065
                num_workers=4,
2066
                prefetch_factor=3,
2067
            )
2068
        )
2069

2070
    def test_random_sampler(self):
2071
        from collections import Counter
2072

2073
        from torch.utils.data import RandomSampler
2074

2075
        def sample_stat(sampler, num_samples):
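            # Returns (number of indices drawn more than once, smallest index,
            # largest index, total number of draws) for the given sampler.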
2076
            counts = Counter(sampler)
2077
            count_repeated = sum(val > 1 for val in counts.values())
2078
            return (
2079
                count_repeated,
2080
                min(counts.keys()),
2081
                max(counts.keys()),
2082
                sum(counts.values()),
2083
            )
2084

2085
        # test sample with replacement
2086
        n = len(self.dataset) + 1  # ensure at least one sample is drawn more than once
2087
        sampler_with_replacement = RandomSampler(
2088
            self.dataset, replacement=True, num_samples=n
2089
        )
2090
        count_repeated, minval, maxval, count_total = sample_stat(
2091
            sampler_with_replacement, n
2092
        )
2093
        self.assertTrue(count_repeated > 0)
2094
        self.assertTrue(minval >= 0)
2095
        self.assertTrue(maxval < len(self.dataset))
2096
        self.assertTrue(count_total == n)
2097

2098
        # test sample without replacement and without specified num_samples
2099
        sampler_without_replacement = RandomSampler(self.dataset)
2100
        count_repeated, minval, maxval, count_total = sample_stat(
2101
            sampler_without_replacement, len(self.dataset)
2102
        )
2103
        self.assertTrue(count_repeated == 0)
2104
        self.assertTrue(minval == 0)
2105
        self.assertTrue(maxval == len(self.dataset) - 1)
2106
        self.assertTrue(count_total == len(self.dataset))
2107

2108
        # test sample without replacement and with specified num_samples
2109
        n = len(self.dataset) * 2
2110
        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2111
        count_repeated, minval, maxval, count_total = sample_stat(
2112
            sampler_without_replacement, len(self.dataset)
2113
        )
2114
        self.assertTrue(count_repeated == len(self.dataset))
2115
        self.assertTrue(minval == 0)
2116
        self.assertTrue(maxval == len(self.dataset) - 1)
2117
        self.assertTrue(count_total == n)
2118

2119
        n = len(self.dataset) - 1
2120
        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2121
        count_repeated, minval, maxval, count_total = sample_stat(
2122
            sampler_without_replacement, len(self.dataset)
2123
        )
2124
        self.assertTrue(count_repeated == 0)
2125
        self.assertTrue(minval >= 0)
2126
        self.assertTrue(maxval < len(self.dataset))
2127
        self.assertTrue(count_total == n)
2128

2129
        n = len(self.dataset) + 1
2130
        sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2131
        count_repeated, minval, maxval, count_total = sample_stat(
2132
            sampler_without_replacement, len(self.dataset)
2133
        )
2134
        self.assertTrue(count_repeated == 1)
2135
        self.assertTrue(minval == 0)
2136
        self.assertTrue(maxval == len(self.dataset) - 1)
2137
        self.assertTrue(count_total == n)
2138

2139
        # raise error when replacement is non-boolean
2140
        with self.assertRaisesRegex(
2141
            TypeError, "replacement should be a boolean value, but got replacement=0"
2142
        ):
2143
            RandomSampler(self.dataset, replacement=0)
2144

2145
    def test_random_sampler_len_with_replacement(self):
2146
        from torch.utils.data import RandomSampler
2147

2148
        # add 5 extra samples
2149
        num_samples = len(self.dataset) + 5
2150
        sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples)
2151
        # test len method
2152
        self.assertEqual(num_samples, len(sampler))
2153

2154
        # test with iteration
2155
        count_num_samples = sum(1 for _ in sampler)
2156
        self.assertEqual(num_samples, count_num_samples)
2157

2158
        # test with dataloader, batch_size = 1
2159
        batch_size = 1
2160
        count_num_samples_in_data_loader = len(
2161
            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2162
        )
2163
        self.assertEqual(num_samples, count_num_samples_in_data_loader)
2164

2165
        # test with dataloader, batch_size = 6
2166
        batch_size = 6
2167
        count_num_samples_in_data_loader = len(
2168
            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2169
        )
2170
        self.assertEqual(
2171
            int(math.ceil(float(num_samples) / batch_size)),
2172
            count_num_samples_in_data_loader,
2173
        )
2174

2175
    def test_random_sampler_len_without_replacement(self):
2176
        from torch.utils.data import RandomSampler
2177

2178
        # add 5 extra samples
2179
        num_samples = len(self.dataset) + 5
2180
        sampler = RandomSampler(
2181
            self.dataset, replacement=False, num_samples=num_samples
2182
        )
2183
        # test len method
2184
        self.assertEqual(num_samples, len(sampler))
2185

2186
        # test with iteration
2187
        count_num_samples = sum(1 for _ in sampler)
2188
        self.assertEqual(num_samples, count_num_samples)
2189

2190
        # test with dataloader, batch_size = 1
2191
        batch_size = 1
2192
        count_num_samples_in_data_loader = len(
2193
            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2194
        )
2195
        self.assertEqual(num_samples, count_num_samples_in_data_loader)
2196

2197
        # test with dataloader, batch_size = 6
2198
        batch_size = 6
2199
        count_num_samples_in_data_loader = len(
2200
            self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2201
        )
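        # drop_last defaults to False, so len(dataloader) is the ceiling of
        # num_samples / batch_size, which the expression below spells out.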
2202
        self.assertEqual(
2203
            num_samples // batch_size + (num_samples % batch_size > 0),
2204
            count_num_samples_in_data_loader,
2205
        )
2206

2207
    def test_distributed_sampler_invalid_rank(self):
2208
        from torch.utils.data.distributed import DistributedSampler
2209

2210
        dataset = torch.IntTensor(range(10))
2211
        with self.assertRaisesRegex(ValueError, "Invalid rank"):
2212
            sampler = DistributedSampler(dataset, 3, 3)
2213

2214
        with self.assertRaisesRegex(ValueError, "Invalid rank"):
2215
            sampler = DistributedSampler(dataset, 3, -1)
2216

2217
    def test_duplicating_data_with_drop_last(self):
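        # With drop_last=True and one DistributedSampler per rank, every rank
        # should see a disjoint slice of the data, so concatenating what all
        # ranks loaded must contain no duplicates (size == unique size).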
2218
        from torch.utils.data.distributed import DistributedSampler
2219

2220
        num_processes = 4
2221
        num_batches = 9
2222
        data_set = torch.IntTensor(range(num_batches))
2223
        scanned_data = torch.IntTensor([])
2224
        for i in range(num_processes):
2225
            s = DistributedSampler(data_set, num_processes, i)
2226
            d_loader = self._get_data_loader(
2227
                data_set,
2228
                batch_size=int(num_batches / num_processes),
2229
                drop_last=True,
2230
                sampler=s,
2231
            )
2232
            for data in d_loader:
2233
                scanned_data = torch.cat((scanned_data, data), 0)
2234

2235
        self.assertEqual(scanned_data.size(), scanned_data.unique().size())
2236

2237
    def test_sampler_reproducibility(self):
2238
        from torch.utils.data import (
2239
            RandomSampler,
2240
            SubsetRandomSampler,
2241
            WeightedRandomSampler,
2242
        )
2243

2244
        weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
2245
        for fn in (
2246
            lambda: RandomSampler(
2247
                self.dataset,
2248
                num_samples=5,
2249
                replacement=True,
2250
                generator=torch.Generator().manual_seed(42),
2251
            ),
2252
            lambda: RandomSampler(
2253
                self.dataset,
2254
                replacement=False,
2255
                generator=torch.Generator().manual_seed(42),
2256
            ),
2257
            lambda: WeightedRandomSampler(
2258
                weights,
2259
                num_samples=5,
2260
                replacement=True,
2261
                generator=torch.Generator().manual_seed(42),
2262
            ),
2263
            lambda: WeightedRandomSampler(
2264
                weights,
2265
                num_samples=5,
2266
                replacement=False,
2267
                generator=torch.Generator().manual_seed(42),
2268
            ),
2269
            lambda: SubsetRandomSampler(
2270
                range(10), generator=torch.Generator().manual_seed(42)
2271
            ),
2272
        ):
2273
            self.assertEqual(list(fn()), list(fn()))
2274

2275
        for sampler in (
2276
            RandomSampler(self.dataset, num_samples=5, replacement=True),
2277
            RandomSampler(self.dataset, replacement=False),
2278
            WeightedRandomSampler(weights, num_samples=5, replacement=True),
2279
            WeightedRandomSampler(weights, num_samples=5, replacement=False),
2280
            SubsetRandomSampler(range(10)),
2281
        ):
2282
            torch.manual_seed(0)
2283
            l1 = list(sampler) + list(sampler)
2284

2285
            torch.manual_seed(0)
2286
            l2 = list(sampler) + list(sampler)
2287
            self.assertEqual(l1, l2)
2288

2289
            its = (iter(sampler), iter(sampler))
2290
            ls = ([], [])
2291
            for idx in range(len(sampler)):
2292
                for i in range(2):
2293
                    if idx == 0:
2294
                        torch.manual_seed(0)
2295
                    ls[i].append(next(its[i]))
2296
            self.assertEqual(ls[0], ls[1])
2297

2298
    def _test_sampler(self, **kwargs):
2299
        indices = range(2, 12)  # using a regular iterable
2300
        dl = self._get_data_loader(
2301
            self.dataset, sampler=indices, batch_size=2, **kwargs
2302
        )
2303
        self.assertEqual(len(dl), 5)
2304
        for i, (input, _target) in enumerate(dl):
2305
            self.assertEqual(len(input), 2)
2306
            self.assertEqual(input, self.data[i * 2 + 2 : i * 2 + 4])
2307

2308
    def test_sampler(self):
2309
        self._test_sampler()
2310
        self._test_sampler(num_workers=4)
2311
        if not NO_MULTIPROCESSING_SPAWN:
2312
            self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn")
2313

2314
    def _test_batch_sampler(self, **kwargs):
2315
        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
2316
        batches = []  # using a regular iterable
2317
        for i in range(0, 20, 5):
2318
            batches.append(tuple(range(i, i + 2)))
2319
            batches.append(tuple(range(i + 2, i + 5)))
2320

2321
        dl = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs)
2322
        self.assertEqual(len(dl), 8)
2323
        for i, (input, _target) in enumerate(dl):
2324
            if i % 2 == 0:
2325
                offset = i * 5 // 2
2326
                self.assertEqual(len(input), 2)
2327
                self.assertEqual(input, self.data[offset : offset + 2])
2328
            else:
2329
                offset = i * 5 // 2
2330
                self.assertEqual(len(input), 3)
2331
                self.assertEqual(input, self.data[offset : offset + 3])
2332

2333
    def test_batch_sampler(self):
2334
        self._test_batch_sampler()
2335
        self._test_batch_sampler(num_workers=4)
2336
        if not NO_MULTIPROCESSING_SPAWN:
2337
            self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn")
2338

2339
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
2340
    def test_shuffle_pin_memory(self):
2341
        loader = self._get_data_loader(
2342
            self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
2343
        )
2344
        for input, target in loader:
2345
            self.assertTrue(input.is_pinned())
2346
            self.assertTrue(target.is_pinned())
2347

2348
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2349
    def test_numpy(self):
2350
        import numpy as np
2351

2352
        class TestDataset(torch.utils.data.Dataset):
2353
            def __getitem__(self, i):
2354
                return np.ones((2, 3, 4)) * i
2355

2356
            def __len__(self):
2357
                return 1000
2358

2359
        loader = self._get_data_loader(TestDataset(), batch_size=12)
2360
        batch = next(iter(loader))
2361
        self.assertIsInstance(batch, torch.DoubleTensor)
2362
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
2363

2364
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2365
    def test_numpy_gen_state(self):
2366
        from torch.utils.data._utils.worker import _generate_state
2367

2368
        # Use NumPy-generated states as the reference to check that `_generate_state`
2369
        # produces the same results.
2370
        # Test case: ((worker_id, base_seed), expected_state)
2371
        test_cases = [
2372
            (
2373
                (4, 13434589827475259383),
2374
                (2884386318, 1088094898, 3523808998, 3860348662),
2375
            ),
2376
            ((1, 15014285634777110771), (1934848465, 763213760, 2959016433, 179751970)),
2377
            (
2378
                (10, 978296274032934101),
2379
                (1759791917, 3550927336, 1225977135, 1036538043),
2380
            ),
2381
            (
2382
                (12, 11868770762134256968),
2383
                (3974661794, 3331131333, 3630387033, 2885815368),
2384
            ),
2385
            (
2386
                (9, 15378787925219019706),
2387
                (3815056996, 3162224466, 2735102421, 3190253477),
2388
            ),
2389
            ((5, 9055612723125076328), (3522565701, 3368424109, 959377806, 621878693)),
2390
            (
2391
                (15, 14617792358407278405),
2392
                (3402479508, 1588702753, 1169536393, 3675067356),
2393
            ),
2394
            (
2395
                (9, 17363320784006640087),
2396
                (957989458, 2518334477, 1421725660, 3086155459),
2397
            ),
2398
            (
2399
                (12, 480002904169484764),
2400
                (2732851467, 1762620729, 4055801988, 1277640511),
2401
            ),
2402
            (
2403
                (15, 16803975943592702950),
2404
                (3479415043, 4022359553, 295994005, 3358606349),
2405
            ),
2406
            (
2407
                (9, 11704776406047813044),
2408
                (1968928009, 710113752, 2442656196, 1587420279),
2409
            ),
2410
            (
2411
                (10, 16357891985431864516),
2412
                (1271733898, 4197047399, 3727213786, 2338547348),
2413
            ),
2414
            (
2415
                (2, 17423369006318065007),
2416
                (544294336, 1911284083, 3299147734, 3231058347),
2417
            ),
2418
            ((2, 2889492011444113593), (3721591783, 2595811276, 2212881745, 977682627)),
2419
            ((0, 8979703111668486195), (4276723937, 2556068849, 2962827292, 233130238)),
2420
            (
2421
                (6, 6269787272229682235),
2422
                (2548857855, 1216457374, 1012973562, 2999759647),
2423
            ),
2424
        ]
2425

2426
        for (worker_id, base_seed), exp in test_cases:
2427
            self.assertEqual(exp, _generate_state(base_seed, worker_id))
2428

2429
    def test_error(self):
2430
        self._test_error(
2431
            self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True)
2432
        )
2433

2434
    def test_error_workers(self):
2435
        self._test_error(
2436
            self._get_data_loader(
2437
                ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4
2438
            )
2439
        )
2440

2441
    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
2442
    def test_partial_workers(self):
2443
        r"""Check that workers exit even if the iterator is not exhausted."""
2444
        if TEST_CUDA:
2445
            pin_memory_configs = (True, False)
2446
        else:
2447
            pin_memory_configs = (False,)
2448

2449
        for pin_memory in pin_memory_configs:
2450
            loader = iter(
2451
                self._get_data_loader(
2452
                    self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory
2453
                )
2454
            )
2455
            workers = loader._workers
2456
            if pin_memory:
2457
                pin_memory_thread = loader._pin_memory_thread
2458
            for i, _ in enumerate(loader):
2459
                if i == 10:
2460
                    break
2461
            assert i == 10
2462
            del loader
2463
            for w in workers:
2464
                w.join(JOIN_TIMEOUT)
2465
                self.assertFalse(w.is_alive(), "subprocess not terminated")
2466
            if pin_memory:
2467
                pin_memory_thread.join(JOIN_TIMEOUT)
2468
                self.assertFalse(pin_memory_thread.is_alive())
2469

2470
    # Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
2471
    @skipIfRocm
2472
    @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
2473
    @slowTest
2474
    def test_proper_exit(self):
2475
        (
2476
            r"""There might be ConnectionResetError or leaked semaphore warning """
2477
            r"""(due to dirty process exit), but they are all safe to ignore"""
2478
        )
2479

2480
        # TODO: test the case where the pin_memory_thread triggers an
2481
        #       error/fatal signal. I haven't found out how to properly do that.
2482

2483
        for (
2484
            is_iterable_dataset,
2485
            use_workers,
2486
            pin_memory,
2487
            hold_iter_reference,
2488
        ) in itertools.product([True, False], repeat=4):
2489
            # `hold_iter_reference` specifies whether we hold a reference to the
2490
            # iterator. This is interesting because Python 3 error traces hold a
2491
            # reference to the frames, which hold references to all the local
2492
            # variables including the iterator, and then the iterator dtor may
2493
            # not be called before process end. It is important to see that the
2494
            # processes still exit in both cases.
2495

2496
            if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS):
2497
                # This test runs in a subprocess, which can only initialize CUDA with spawn.
2498
                # DataLoader with pin_memory=True initializes CUDA when its iterator is constructed.
2499
                # For windows, pin_memory sometimes causes CUDA oom.
2500
                continue
2501

2502
            # `exit_method` controls the way the loader process ends.
2503
            #   - `*_kill` means that `*` is killed by OS.
2504
            #   - `*_error` means that `*` raises an error.
2505
            #   - `None` means that no error happens.
2506
            # In all cases, all processes should end properly.
2507
            if use_workers:
2508
                # TODO: Fix test for 'loader_kill' that would cause running out of shared memory.
2509
                # Killing the loader process would prevent the DataLoader iterator from cleaning up
2510
                # all queues and worker processes.
2511
                exit_methods = [None, "loader_error", "worker_error", "worker_kill"]
2512
                persistent_workers = self.persistent_workers
2513
            else:
2514
                exit_methods = [None, "loader_error", "loader_kill"]
2515
                persistent_workers = False
2516

2517
            for exit_method in exit_methods:
2518
                if exit_method == "worker_kill":
2519
                    # FIXME: This sometimes hangs. See #16608.
2520
                    continue
2521

2522
                desc = []
2523
                desc.append(f"is_iterable_dataset={is_iterable_dataset}")
2524
                desc.append(f"use_workers={use_workers}")
2525
                desc.append(f"pin_memory={pin_memory}")
2526
                desc.append(f"hold_iter_reference={hold_iter_reference}")
2527
                desc.append(f"exit_method={exit_method}")
2528
                desc = "test_proper_exit with " + ", ".join(desc)
2529

2530
                # Event that the loader process uses to signal the testing process
2531
                # that various things are set up, including that the worker pids
2532
                # are specified in `worker_pids` array.
2533
                loader_setup_event = mp.Event()
2534

2535
                # Event that this process has finished setting up, and the
2536
                # loader process can now proceed to trigger error events or
2537
                # finish normally.
2538
                tester_setup_event = mp.Event()
2539

2540
                loader_p = ErrorTrackingProcess(
2541
                    target=_test_proper_exit,
2542
                    args=(
2543
                        is_iterable_dataset,
2544
                        use_workers,
2545
                        pin_memory,
2546
                        exit_method,
2547
                        hold_iter_reference,
2548
                        loader_setup_event,
2549
                        tester_setup_event,
2550
                        persistent_workers,
2551
                    ),
2552
                    disable_stderr=False,
2553
                )
2554
                loader_p.start()
2555
                loader_psutil_p = psutil.Process(loader_p.pid)
2556

2557
                # Wait for loader process to set everything up, e.g., starting
2558
                # workers.
2559
                loader_setup_event.wait(timeout=JOIN_TIMEOUT)
2560
                if not loader_setup_event.is_set():
2561
                    fail_msg = (
2562
                        desc + ": loader process failed to setup within given time"
2563
                    )
2564
                    if loader_p.exception is not None:
2565
                        fail_msg += f", and had exception {loader_p.exception}"
2566
                    elif not loader_p.is_alive():
2567
                        fail_msg += f", and exited with code {loader_p.exitcode} but had no exception"
2568
                    else:
2569
                        fail_msg += ", and is still alive."
2570
                    if loader_p.is_alive():
2571
                        # this may kill the process, needs to run after the above lines
2572
                        loader_p.print_traces_of_all_threads()
2573
                    self.fail(fail_msg)
2574

2575
                # We are certain that the workers have started now.
2576
                worker_psutil_ps = loader_psutil_p.children()
2577

2578
                def fail(reason):
2579
                    report_psutil_attrs = [
2580
                        "pid",
2581
                        "name",
2582
                        "cpu_times",
2583
                        "io_counters",
2584
                        "memory_full_info",
2585
                        "num_ctx_switches",
2586
                        "open_files",
2587
                        "threads",
2588
                        "status",
2589
                        "nice",
2590
                        "ionice",
2591
                    ]
2592
                    if reason is None:
2593
                        err_msg = desc
2594
                    else:
2595
                        err_msg = f"{desc}: {reason}"
2596
                    err_msg += "\nLoader info:\n\t"
2597
                    if loader_psutil_p.is_running():
2598
                        err_msg += str(
2599
                            loader_psutil_p.as_dict(attrs=report_psutil_attrs)
2600
                        )
2601
                        # this may kill the process, needs to run after the above line
2602
                        loader_p.print_traces_of_all_threads()
2603
                    else:
2604
                        err_msg += f"exited with code {loader_p.exitcode}"
2605
                    if use_workers:
2606
                        err_msg += "\nWorker(s) info:"
2607
                        for idx, worker_psutil_p in enumerate(worker_psutil_ps):
2608
                            err_msg += f"\n\tWorker {idx}:\n\t\t"
2609
                            if worker_psutil_p.is_running():
2610
                                err_msg += str(
2611
                                    worker_psutil_p.as_dict(attrs=report_psutil_attrs)
2612
                                )
2613
                                # this may kill the process, needs to run after the above line
2614
                                print_traces_of_all_threads(worker_psutil_p.pid)
2615
                            else:
2616
                                err_msg += "exited with unknown code"
2617
                    self.fail(err_msg)
2618

2619
                tester_setup_event.set()
2620

2621
                try:
2622
                    loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
2623
                    if loader_p.is_alive():
2624
                        fail_reason = "loader process did not terminate"
2625
                        if loader_p.exception is not None:
2626
                            fail(
2627
                                fail_reason
2628
                                + f", and had exception {loader_p.exception}"
2629
                            )
2630
                        else:
2631
                            fail(fail_reason + ", and had no exception")
2632
                    _, alive = psutil.wait_procs(
2633
                        worker_psutil_ps,
2634
                        timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT),
2635
                    )
2636
                    if len(alive) > 0:
2637
                        fail(
2638
                            "worker process (pid(s) {}) did not terminate".format(
2639
                                ", ".join(str(p.pid) for p in alive)
2640
                            )
2641
                        )
2642
                    if exit_method is None:
2643
                        if loader_p.exitcode != 0:
2644
                            fail(
2645
                                f"loader process had nonzero exitcode {loader_p.exitcode}"
2646
                            )
2647
                    else:
2648
                        if loader_p.exitcode == 0:
2649
                            fail("loader process had zero exitcode")
2650
                        if exit_method == "loader_error":
2651
                            if not isinstance(
2652
                                loader_p.exception, RuntimeError
2653
                            ) or "Loader error" not in str(loader_p.exception):
2654
                                fail(
2655
                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2656
                                )
2657
                        elif exit_method == "worker_kill":
2658
                            if isinstance(loader_p.exception, RuntimeError):
2659
                                if "DataLoader worker (pid" not in str(
2660
                                    loader_p.exception
2661
                                ):
2662
                                    fail(
2663
                                        f"loader process did not raise expected exception, but had {loader_p.exception}"
2664
                                    )
2665
                            elif isinstance(loader_p.exception, ConnectionRefusedError):
2666
                                # Sometimes, when the worker is being killed and is freeing its
2667
                                # resources, the unpickling in the loader process may hit a
2668
                                # `ConnectionRefusedError` because it cannot open a socket to receive the
2669
                                # resource. In such cases, the worker may not have fully exited,
2670
                                # and the loader can't know this via `is_alive` check or `SIGCHLD`
2671
                                # handler. So we permit this as an allowed error as well.
2672
                                # After all, we are happy as long as it terminates.
2673
                                pass
2674
                            else:
2675
                                fail(
2676
                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2677
                                )
2678
                        elif exit_method == "worker_error":
2679
                            if not isinstance(
2680
                                loader_p.exception, RuntimeError
2681
                            ) or "Worker error" not in str(loader_p.exception):
2682
                                fail(
2683
                                    f"loader process did not raise expected exception, but had {loader_p.exception}"
2684
                                )
2685
                finally:
2686
                    loader_p.terminate()
2687

2688
    def test_len(self):
2689
        def check_len(dl, expected):
2690
            self.assertEqual(len(dl), expected)
2691
            n = 0
2692
            for _ in dl:
2693
                n += 1
2694
            self.assertEqual(n, expected)
2695

2696
        check_len(self.dataset, 100)
2697
        check_len(self._get_data_loader(self.dataset, batch_size=2), 50)
2698
        check_len(self._get_data_loader(self.dataset, batch_size=3), 34)
2699

2700
    def test_iterabledataset_len(self):
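        # For an IterableDataset that implements __len__, len(DataLoader) is only
        # an estimate derived from that reported length: ceil(len / batch_size) by
        # default, and floor(len / batch_size) when drop_last=True, as the
        # batch_size=3 cases below show. See NOTE [ IterableDataset and __len__ ].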
2701
        class IterableDataset(torch.utils.data.IterableDataset):
2702
            def __len__(self):
2703
                return 10
2704

2705
            def __iter__(self):
2706
                return iter(range(10))
2707

2708
        iterable_loader = DataLoader(IterableDataset(), batch_size=1)
        self.assertEqual(len(iterable_loader), 10)
        iterable_loader = DataLoader(IterableDataset(), batch_size=1, drop_last=True)
        self.assertEqual(len(iterable_loader), 10)

        iterable_loader = DataLoader(IterableDataset(), batch_size=2)
        self.assertEqual(len(iterable_loader), 5)
        iterable_loader = DataLoader(IterableDataset(), batch_size=2, drop_last=True)
        self.assertEqual(len(iterable_loader), 5)

        iterable_loader = DataLoader(IterableDataset(), batch_size=3)
        self.assertEqual(len(iterable_loader), 4)
        iterable_loader = DataLoader(IterableDataset(), batch_size=3, drop_last=True)
        self.assertEqual(len(iterable_loader), 3)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = self._get_data_loader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)

    def test_default_convert_mapping_keep_type(self):
        data = CustomDict({"a": 1, "b": 2})
        converted = _utils.collate.default_convert(data)

        self.assertEqual(converted, data)

    def test_default_convert_sequence_keep_type(self):
        data = CustomList([1, 2, 3])
        converted = _utils.collate.default_convert(data)

        self.assertEqual(converted, data)

    def test_default_convert_sequence_dont_keep_type(self):
        data = range(2)
        converted = _utils.collate.default_convert(data)

        self.assertEqual(converted, [0, 1])

    def test_default_collate_dtype(self):
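        # default_collate on a list of Python scalars builds a single tensor whose
        # dtype follows torch.tensor's type inference: ints -> int64,
        # floats -> float64, bools -> bool; strings are left untouched.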
        arr = [1, 2, -1]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.int64)

        arr = [1.1, 2.3, -0.9]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))

        arr = [True, False]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.bool)

        # Should be a no-op
        arr = ["a", "b", "c"]
        self.assertEqual(arr, _utils.collate.default_collate(arr))

    def test_default_collate_mapping_keep_type(self):
        batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
        collated = _utils.collate.default_collate(batch)

        expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
        self.assertEqual(collated, expected)

    def test_default_collate_sequence_keep_type(self):
        batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
        collated = _utils.collate.default_collate(batch)

        expected = CustomList(
            [
                torch.tensor([1, 4]),
                torch.tensor([2, 5]),
                torch.tensor([3, 6]),
            ]
        )
        self.assertEqual(collated, expected)

    def test_default_collate_sequence_dont_keep_type(self):
        batch = [range(2), range(2)]
        collated = _utils.collate.default_collate(batch)

        self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(["a", "b", "c"])
        self.assertEqual(arr, _utils.collate.default_collate(arr))

        arr = np.array([[["a", "b", "c"]]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_numpy_memmap(self):
        import numpy as np

        with tempfile.TemporaryFile() as f:
            arr = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
            arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
            arr_memmap[:] = arr[:]
            arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
            tensor = _utils.collate.default_collate(list(arr_new))

        self.assertTrue(
            (tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item()
        )

    def test_default_collate_bad_sequence_type(self):
        batch = [["X"], ["X", "X"]]
        self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
        self.assertRaises(
            RuntimeError, lambda: _utils.collate.default_collate(batch[::-1])
        )

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_shared_tensor(self):
        import numpy as np

        t_in = torch.zeros(1)
        n_in = np.zeros(1)

        self.assertEqual(t_in.is_shared(), False)

        self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
        self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)

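        # Inside a worker process, default_collate allocates the collated tensor in
        # shared memory so it can be sent back to the main process without an extra
        # copy; get_worker_info() being non-None is what signals "worker process".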
        # FIXME: fix the following hack that makes `default_collate` believe
        #        that it is in a worker process (since it tests
        #        `get_worker_info() != None`), even though it is not.
        old = _utils.worker._worker_info
        try:
            _utils.worker._worker_info = "x"
            self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
            self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
        finally:
            _utils.worker._worker_info = old

    def test_excessive_thread_creation_warning(self):
        with self.assertWarnsRegex(
            UserWarning,
            r"excessive worker creation might get DataLoader running slow or even freeze",
        ):
            dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)


class TestDataLoaderDeviceType(TestCase):
    @parametrize(
        "context",
        [ctx for ctx in supported_multiprocessing_contexts if ctx is not None],
    )
    @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
    def test_nested_tensor_multiprocessing(self, device, context):
        # The 'fork' multiprocessing context doesn't work for CUDA, so skip it
        if "cuda" in device and context == "fork":
            # TODO: Skip this in a better way when the test framework allows
            return

        dataset = [
            torch.nested.nested_tensor([torch.randn(5)], device=device)
            for _ in range(10)
        ]

        pin_memory_settings = [False]
        if device == "cpu" and torch.cuda.is_available():
            pin_memory_settings.append(True)

        for pin_memory in pin_memory_settings:
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=1,
                num_workers=4,
                collate_fn=_clone_collate,
                pin_memory=pin_memory,
                multiprocessing_context=context,
            )

            for i, batch in enumerate(loader):
                self.assertEqual(batch[0], dataset[i])

        # Error case: default collate_fn doesn't currently support batches of nested tensors.
        # Following the current semantics, we'd need to stack them, which isn't possible at the moment.
        with self.assertRaisesRegex(
            RuntimeError, "not currently supported by the default collate_fn"
        ):
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=1,
                num_workers=4,
                multiprocessing_context=context,
            )

            next(iter(loader))


class IntegrationTestDataLoaderDataPipe(TestCase):
    r"""
    Verify the behavior of certain ``DataPipes`` with ``DataLoader``
    """

    def test_shuffler_iterdatapipe(self):
        r"""
        Verify ``IterDataPipe.shuffle`` is controlled by ``DataLoader``
        to generate different seeds deterministically per epoch.
        """
        exp = list(range(100))

        def _create_dp(buffer_size):
            input_ds = dp.iter.IterableWrapper(exp)
            return input_ds.shuffle(buffer_size=buffer_size).sharding_filter()

        for bs in (5, 20, 33):
            # Test deterministic behavior
            for num_workers, pw in itertools.product((0, 1, 2), (True, False)):
                if num_workers == 0 and pw:
                    continue

                shuffle_dp = _create_dp(bs)

                mp_ctx = "spawn" if num_workers > 0 else None
                dl = DataLoader(
                    shuffle_dp,
                    num_workers=num_workers,
                    shuffle=True,
                    multiprocessing_context=mp_ctx,
                    persistent_workers=pw,
                )

                # No seed
                dl_res_ns = list(dl)
                self.assertEqual(sorted(dl_res_ns), exp)

                # Same seeds
                dl_res = []
                for epoch in range(2):
                    torch.manual_seed(123)
                    dl_res.append(list(dl))
                self.assertEqual(dl_res[0], dl_res[1])
                self.assertEqual(sorted(dl_res[0]), exp)

                # Different seeds
                torch.manual_seed(321)
                dl_res.append(list(dl))

                self.assertEqual(len(dl_res[0]), len(dl_res[2]))
                self.assertNotEqual(dl_res[0], dl_res[2])
                self.assertEqual(sorted(dl_res[0]), sorted(dl_res[2]))

                if dl._iterator is not None:
                    dl._iterator._shutdown_workers()
                    dl._iterator = None
                del dl


class StringDataset(Dataset):
    def __init__(self) -> None:
        self.s = "12345"

    def __len__(self):
        return len(self.s)

    def __getitem__(self, ndx):
        return (self.s[ndx], ndx)


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStringDataLoader(TestCase):
    def setUp(self):
        super().setUp()
        self.dataset = StringDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(
            self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
        )
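        # pin_memory only applies to tensors; string elements pass through
        # collation and the pin_memory thread unchanged.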
        for s, n in loader:
            self.assertIsInstance(s[0], str)
            self.assertTrue(n.is_pinned())


class DictDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return {
            "a_tensor": torch.empty(4, 2).fill_(ndx),
            "another_dict": {"a_number": ndx},
        }


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestDictDataLoader(TestCase):
    def setUp(self):
        super().setUp()
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        for persistent_workers in (False, True):
            if persistent_workers:
                loader = DataLoader(
                    self.dataset,
                    batch_size=2,
                    shuffle=False,
                    persistent_workers=persistent_workers,
                    num_workers=1,
                )
            else:
                loader = DataLoader(
                    self.dataset,
                    batch_size=2,
                    shuffle=False,
                    persistent_workers=persistent_workers,
                )
            batch_size = loader.batch_size
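            # default_collate is applied recursively, so nested dicts keep their
            # structure while their leaf values are batched into tensors.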
            for i, sample in enumerate(loader):
                idx = i * batch_size
                self.assertEqual(set(sample.keys()), {"a_tensor", "another_dict"})
                self.assertEqual(set(sample["another_dict"].keys()), {"a_number"})

                t = sample["a_tensor"]
                self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
                self.assertTrue((t[0] == idx).all())
                self.assertTrue((t[1] == idx + 1).all())

                n = sample["another_dict"]["a_number"]
                self.assertEqual(n.size(), torch.Size([batch_size]))
                self.assertEqual(n[0], idx)
                self.assertEqual(n[1], idx + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for sample in loader:
            self.assertTrue(sample["a_tensor"].is_pinned())
            self.assertTrue(sample["another_dict"]["a_number"].is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_device(self):
        loader = DataLoader(
            self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
        )
        for sample in loader:
            self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
            self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda"))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_with_only_device(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
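        # pin_memory_device without pin_memory=True is a no-op, so samples are
        # expected to come back unpinned.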
        for sample in loader:
            self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
            self.assertFalse(
                sample["another_dict"]["a_number"].is_pinned(device="cuda")
            )


class DummyDataset(torch.utils.data.Dataset):
    def __init__(self) -> None:
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Persistent workers keep the original dataset alive for the lifetime of
        # the DataLoader, so its attributes stay as they were when the workers
        # were first spawned (i.e., at the first DataLoader iteration).
        assert self.start == 0
        return self.data[idx]


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(
    TEST_WITH_ASAN,
    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
)
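# Re-runs the TestDataLoader suite with self.persistent_workers = True, plus a few
# persistent-worker specific regression tests below.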
class TestDataLoaderPersistentWorkers(TestDataLoader):
    def setUp(self):
        super().setUp()
        self.persistent_workers = True

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
    def test_fd_limit_exceeded(self):
        # See NOTE [ DataLoader on Linux and open files limit ]
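        # The child process lowers RLIMIT_NOFILE so that sharing tensors through
        # file descriptors runs out of open files, then checks that the resulting
        # error message points the user at `ulimit -n` and set_sharing_strategy.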
        import subprocess

        subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import torch
import resource
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

try:
    keep_fds_alive = []
    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
                               num_workers=1, persistent_workers=True):
      random_t.max(dim=0)
      keep_fds_alive.append(random_t)
except RuntimeError as e:
    assert "ulimit -n" in str(e)
    assert "set_sharing_strategy" in str(e)
""",
            ]
        )

    def test_dataset_not_reset(self):
        dataset = DummyDataset()
        pin_memory_configs = [False]
        if TEST_CUDA:
            pin_memory_configs.append(True)
        for pin_memory in pin_memory_configs:
            dataloader = self._get_data_loader(
                dataset, num_workers=2, pin_memory=pin_memory
            )
            dataset.start = 0
            for i in range(10):
                for x in dataloader:
                    pass
                # Changing the start value here has no effect on the dataset
                # cached by the workers, since persistent workers are not
                # recreated between epochs and can safely cache values.
                dataset.start = i

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "Needs fork")
    def test_early_exit(self):
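        # Breaking out of a persistent-worker, pin_memory DataLoader early in a
        # child process exercises DataLoader shutdown; the test passes if the
        # child exits cleanly (check_output returns) instead of hanging.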
        import subprocess

        proc = subprocess.check_output(
            [
                sys.executable,
                "-c",
                """\
import torch
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

if __name__ == '__main__':
    dl = DataLoader(
        RandomDataset(64, (28, 28)),
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        multiprocessing_context="fork",
    )

    for _ in dl:
        break
""",
            ]
        )


class NamedTupleDataset(Dataset):
    from collections import namedtuple

    Batch = namedtuple("Batch", ["data", "label", "random_tensor"])
    Data = namedtuple("Data", ["positive", "negative"])

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return self.Batch(
            data=self.Data(positive=ndx, negative=-ndx),
            label=str(ndx),
            random_tensor=torch.randn(3),
        )


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestNamedTupleDataLoader(TestCase):
    def setUp(self):
        super().setUp()
        self.dataset = NamedTupleDataset()

    def test_dataloader_with_namedtuple(self):
        # auto-collation
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA)
        for batch in loader:
            self.assertIsInstance(batch, NamedTupleDataset.Batch)
            self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
            self.assertIsInstance(batch.data.positive, torch.Tensor)
            self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA)
        # no auto-collation
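        # With batch_size=None there is no auto-collation: each sample is returned
        # as-is (pin_memory still pins the tensors it finds), so `positive` stays
        # a plain int rather than being converted to a tensor.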
        loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA)
        for batch in loader:
            self.assertIsInstance(batch, NamedTupleDataset.Batch)
            self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
            self.assertNotIsInstance(batch.data.positive, torch.Tensor)


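# A custom batch type that implements pin_memory()/is_pinned(); the DataLoader's
# pin_memory logic calls the batch object's own pin_memory method instead of the
# default recursive tensor pinning, which TestCustomPinFn below relies on.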
class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

    def is_pinned(self):
        return self.inp.is_pinned() and self.tgt.is_pinned()


# Workaround for https://github.com/pytorch/pytorch/issues/50661
# Classes defined in `__main__` cannot be correctly unpickled in spawned worker processes.
# See https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming
self_module = __import__(os.path.splitext(os.path.basename(__file__))[0])


def collate_wrapper(batch):
    return self_module.SimpleCustomBatch(batch)


def collate_into_packed_sequence(batch):
    data = torch.stack([sample[0] for sample in batch], 1)
    t, b = data.size()
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)


def collate_into_packed_sequence_batch_first(batch):
    data = torch.stack([sample[0] for sample in batch], 0)
    b, t = data.size()
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(
        data, lengths, batch_first=True, enforce_sorted=False
    )


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestCustomPinFn(TestCase):
    def setUp(self):
        super().setUp()
        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        self.dataset = TensorDataset(inps, tgts)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin(self):
        test_cases = [
            (collate_wrapper, self_module.SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (
                collate_into_packed_sequence_batch_first,
                torch.nn.utils.rnn.PackedSequence,
            ),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(
                self.dataset, batch_size=2, collate_fn=collate_fn, pin_memory=True
            )
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin_worker(self):
        test_cases = [
            (collate_wrapper, self_module.SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (
                collate_into_packed_sequence_batch_first,
                torch.nn.utils.rnn.PackedSequence,
            ),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(
                self.dataset,
                batch_size=2,
                collate_fn=collate_fn,
                pin_memory=True,
                num_workers=1,
            )
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())


class TestWorkerQueueDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.worker_id = None

    def worker_init_fn(self, worker_id):
        self.worker_id = worker_id

    def __getitem__(self, item):
        return self.worker_id, self.data[item]

    def __len__(self):
        return len(self.data)


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(
    TEST_WITH_ASAN,
    "Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727",
)
class TestIndividualWorkerQueue(TestCase):
    def setUp(self):
        super().setUp()
        self.dataset = TestWorkerQueueDataset(list(range(128)))

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            timeout=5,
            worker_init_fn=self.dataset.worker_init_fn,
        )
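        # With shuffle=False, index batches are dispatched to workers round-robin,
        # so batch i should be produced entirely by worker (i % num_workers) and
        # contain the next `batch_size` consecutive samples.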
        current_worker_idx = 0
        for i, (worker_ids, sample) in enumerate(loader):
            self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
            self.assertEqual(
                sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size))
            )
            current_worker_idx += 1
            if current_worker_idx == num_workers:
                current_worker_idx = 0

    def test_ind_worker_queue(self):
        max_num_workers = None
        if hasattr(os, "sched_getaffinity"):
            try:
                max_num_workers = len(os.sched_getaffinity(0))
            except Exception:
                pass
        if max_num_workers is None:
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                # Use half the number of CPUs
                max_num_workers = cpu_count // 2

        if max_num_workers is None:
            max_num_workers = 1

        for batch_size in (8, 16, 32, 64):
            for num_workers in range(0, min(6, max_num_workers)):
                self._run_ind_worker_queue_test(
                    batch_size=batch_size, num_workers=num_workers + 1
                )


class SetAffinityDataset(IterableDataset):
    def __iter__(self):
        torch.randperm(1)
        after = os.sched_getaffinity(0)
        return iter(after)


@unittest.skipIf(
    not hasattr(os, "sched_setaffinity"), "os.sched_setaffinity is not available"
)
class TestSetAffinity(TestCase):
    def test_set_affinity_in_worker_init(self):
        # Query the current affinity mask to avoid setting a disallowed one
        old_affinity = os.sched_getaffinity(0)
        if not old_affinity:
            self.skipTest("No affinity information")
        # Choose an arbitrary CPU from the allowed set
        expected_affinity = list(old_affinity)[-1]

        def worker_set_affinity(_):
            os.sched_setaffinity(0, [expected_affinity])

        dataset = SetAffinityDataset()

        dataloader = torch.utils.data.DataLoader(
            dataset, num_workers=2, worker_init_fn=worker_set_affinity
        )
        for sample in dataloader:
            self.assertEqual(sample, [expected_affinity])


class ConvDataset(Dataset):
    def __init__(self) -> None:
        self.x = torch.ones(1, 1, 24000)
        # Call convolution in the parent process
        self[0]

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return torch.nn.functional.conv1d(self.x, torch.ones(1, 1, 2))


@unittest.skipIf(IS_WINDOWS, "Needs fork")
@unittest.skipIf(
    TEST_WITH_ASAN,
    "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492",
)
class TestConvAfterFork(TestCase):
    # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565
    def test_conv_after_fork(self):
        loader = DataLoader(ConvDataset(), num_workers=1)
        for x in loader:
            self.assertEqual(x.shape, (1, 1, 1, 23999))


instantiate_device_type_tests(TestDataLoaderDeviceType, globals())


if __name__ == "__main__":
    run_tests()
