20
import torch.utils.data.datapipes as dp
21
from torch import multiprocessing as mp
22
from torch._utils import ExceptionWrapper
23
from torch.testing._internal.common_device_type import instantiate_device_type_tests
24
from torch.testing._internal.common_utils import (
31
NO_MULTIPROCESSING_SPAWN,
44
from torch.utils.data import (
56
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
57
from torch.utils.data.datapipes.iter import IterableWrapper
58
from torch.utils.data.dataset import random_split
65
except ModuleNotFoundError:
69
"psutil not found. Some critical data loader tests relying on it "
70
"(e.g., TestDataLoader.test_proper_exit) will not run."
73
raise ModuleNotFoundError(err_msg) from None
75
warnings.warn(err_msg)
82
except ModuleNotFoundError:
85
skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")
89
load_tests = load_tests
92
torch.cuda.is_available()
93
and sys.platform != "darwin"
94
and sys.platform != "win32"
96
and not TEST_WITH_ROCM
99
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
101
if not NO_MULTIPROCESSING_SPAWN:
115
mp = mp.get_context(method="spawn")
129
supported_multiprocessing_contexts = [None] + list(
130
torch.multiprocessing.get_all_start_methods()
135
def _clone_collate(b):
136
return [x.clone() for x in b]
141
"Fails with TSAN with the following error: starting new threads after multi-threaded "
142
"fork is not supported. Dying (set die_after_fork=0 to override)",
144
class TestDatasetRandomSplit(TestCase):
145
def test_lengths_must_equal_dataset_size(self):
    """random_split must reject integer lengths that do not sum to len(dataset)."""
    dataset = [1, 2, 3, 4]
    # 1 + 2 != 4, so this is not a valid partition of the dataset.
    with self.assertRaises(ValueError):
        random_split(dataset, [1, 2])
149
def test_splits_have_correct_size(self):
150
splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
151
self.assertEqual(len(splits), 2)
152
self.assertEqual(len(splits[0]), 2)
153
self.assertEqual(len(splits[1]), 4)
155
splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5])
156
self.assertEqual(len(splits), 2)
157
self.assertEqual(len(splits[0]), 3)
158
self.assertEqual(len(splits[1]), 3)
164
range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1)
171
splits = random_split(
172
range(106), [0.1, 0.2, 0.3, 0.4], generator=torch.Generator().manual_seed(1)
174
self.assertEqual(len(splits[0]), 11)
175
self.assertEqual(len(splits[1]), 22)
176
self.assertEqual(len(splits[2]), 31)
177
self.assertEqual(len(splits[3]), 42)
179
def test_splits_are_mutually_exclusive(self):
180
data = [5, 2, 3, 4, 1, 6]
181
splits = random_split(data, [2, 4])
183
all_values.extend(list(splits[0]))
184
all_values.extend(list(splits[1]))
187
self.assertListEqual(data, all_values)
189
splits = random_split(data, [0.33, 0.67])
191
all_values.extend(list(splits[0]))
192
all_values.extend(list(splits[1]))
195
self.assertListEqual(data, all_values)
198
splits = random_split(data, [0.25, 0.75])
200
all_values.extend(list(splits[0]))
201
all_values.extend(list(splits[1]))
204
self.assertListEqual(data, all_values)
206
def test_splits_indexing_type(self):
207
r"""Indices generated by random_split
208
should be of integer type
212
def __init__(self, test_object, custom_list):
213
self.data = custom_list
214
self.test_object = test_object
216
def __getitem__(self, key):
217
self.test_object.assertEqual(type(key), int)
218
return self.data[key]
221
return len(self.data)
224
dataset = CustomDataset(self, x)
225
dataset = random_split(dataset, [5])[0]
226
data_loader = DataLoader(dataset)
227
for batch in data_loader:
231
dataset = CustomDataset(self, x)
232
dataset = random_split(dataset, [1.0])[0]
233
data_loader = DataLoader(dataset)
234
for batch in data_loader:
237
def test_splits_reproducibility(self):
241
for x in random_split(
242
range(10), [3, 7], generator=torch.Generator().manual_seed(1)
245
[[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]],
249
range(100), [60, 40], generator=torch.Generator().manual_seed(42)
252
range(100), [60, 40], generator=torch.Generator().manual_seed(42)
257
range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
260
range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)
267
generator=torch.Generator().manual_seed(42),
272
generator=torch.Generator().manual_seed(42),
276
def test_incomplete_fractional_splits(self):
    """Fractional split lengths must sum to exactly 1."""
    data = [1, 2, 3, 4]
    # A single fraction below 1 leaves part of the dataset unassigned.
    with self.assertRaises(ValueError):
        random_split(data, [0.1])
    # A fraction above 1 over-allocates the dataset.
    with self.assertRaises(ValueError):
        random_split(data, [1.1])
285
def test_splits_generator(self):
287
state = torch.get_rng_state()
289
torch.set_rng_state(state)
290
random_split(range(10), [5, 5])
292
self.assertNotEqual(a, b)
295
state = torch.get_rng_state()
297
torch.set_rng_state(state)
298
random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42))
300
self.assertEqual(a, b)
302
def test_slicing_of_subset_of_dataset(self):
    """Slicing a Subset must agree with slicing the underlying dataset."""
    dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
    # Identity subset: every slice must match the dataset directly.
    full_subset = Subset(dataset, [0, 1, 2, 3, 4])
    for sl in (slice(None), slice(1, 2), slice(0, -1, 2)):
        self.assertEqual(full_subset[sl], dataset[sl])
    # Subsets from random_split resolve slices through their .indices list.
    subset1, subset2 = random_split(dataset, [3, 2])
    for sl in (slice(None), slice(0, 2), slice(0, -1, 2)):
        self.assertEqual(subset1[sl], dataset[subset1.indices[sl]])
315
def test_slicing_of_subset_of_subset(self):
    """Slicing a Subset of a Subset must resolve through both index maps."""
    dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
    # Two stacked identity subsets behave exactly like the dataset itself.
    inner = Subset(dataset, [0, 1, 2, 3, 4])
    outer = Subset(inner, [0, 1, 2, 3, 4])
    for sl in (slice(None), slice(0, 2), slice(0, -1, 2)):
        self.assertEqual(outer[sl], dataset[sl])
    # Nested random_split subsets: compose the two index mappings by hand
    # and check slices of the nested subset against the base dataset.
    subset1, subset2 = random_split(dataset, [4, 1])
    subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1])
    idx = [subset1.indices[i] for i in subset_of_subset1.indices]
    self.assertEqual(subset_of_subset1[:], dataset[idx.copy()])
    self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]])
    self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]])
332
class CUDACountingDataset(Dataset):
333
def __init__(self, n):
337
def __getitem__(self, i):
338
return torch.as_tensor(i, device="cuda")
344
class CountingDataset(Dataset):
345
def __init__(self, n):
349
def __getitem__(self, i):
356
class CountingIterableDataset(IterableDataset):
357
def __init__(self, n):
362
return iter(range(self.n))
370
"Fails with TSAN with the following error: starting new threads after multi-threaded "
371
"fork is not supported. Dying (set die_after_fork=0 to override)",
373
class TestTensorDataset(TestCase):
375
source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
376
self.assertEqual(len(source), 15)
378
def test_getitem(self):
379
t = torch.randn(15, 10, 2, 3, 4, 5)
380
l = torch.randn(15, 10)
381
source = TensorDataset(t, l)
383
self.assertEqual(t[i], source[i][0])
384
self.assertEqual(l[i], source[i][1])
386
def test_getitem_1d(self):
389
source = TensorDataset(t, l)
391
self.assertEqual(t[i], source[i][0])
392
self.assertEqual(l[i], source[i][1])
394
def test_single_tensor(self):
395
t = torch.randn(5, 10)
396
source = TensorDataset(t)
397
self.assertEqual(len(source), 5)
399
self.assertEqual(t[i], source[i][0])
401
def test_many_tensors(self):
402
t0 = torch.randn(5, 10, 2, 3, 4, 5)
403
t1 = torch.randn(5, 10)
404
t2 = torch.randn(5, 10, 2, 5)
405
t3 = torch.randn(5, 10, 3, 7)
406
source = TensorDataset(t0, t1, t2, t3)
407
self.assertEqual(len(source), 5)
409
self.assertEqual(t0[i], source[i][0])
410
self.assertEqual(t1[i], source[i][1])
411
self.assertEqual(t2[i], source[i][2])
412
self.assertEqual(t3[i], source[i][3])
417
"Fails with TSAN with the following error: starting new threads after multi-threaded "
418
"fork is not supported. Dying (set die_after_fork=0 to override)",
420
class TestStackDataset(TestCase):
421
def test_empty(self):
422
with self.assertRaisesRegex(
423
ValueError, "At least one dataset should be passed"
427
def test_mixed(self):
428
with self.assertRaisesRegex(ValueError, "Supported either"):
430
TensorDataset(torch.randn(15, 10)), a=TensorDataset(torch.randn(10, 15))
433
def test_size_mismatch(self):
434
with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
436
TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(10, 15))
438
with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"):
440
a=TensorDataset(torch.randn(15, 10)),
441
b=TensorDataset(torch.randn(10, 15)),
445
source = StackDataset(
446
TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(15))
448
self.assertEqual(len(source), 15)
449
source = StackDataset(TensorDataset(torch.randn(15, 10)))
450
self.assertEqual(len(source), 15)
451
source = StackDataset(
452
a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(15))
454
self.assertEqual(len(source), 15)
455
source = StackDataset(a=TensorDataset(torch.randn(15, 10)))
456
self.assertEqual(len(source), 15)
458
def test_single(self):
459
t = TensorDataset(torch.randn(15, 10))
460
source = StackDataset(t)
462
self.assertEqual(t[i], source[i][0])
463
source = StackDataset(a=t)
465
self.assertEqual(t[i], source[i]["a"])
467
def test_getitem(self):
468
t = TensorDataset(torch.randn(15, 10))
469
l = TensorDataset(torch.randn(15, 5, 4))
470
source = StackDataset(t, l)
472
self.assertEqual(t[i], source[i][0])
473
self.assertEqual(l[i], source[i][1])
474
source = StackDataset(a=t, b=l)
476
self.assertEqual(t[i], source[i]["a"])
477
self.assertEqual(l[i], source[i]["b"])
479
def test_getitems(self):
480
class GetItemsDataset(Dataset):
481
def __init__(self) -> None:
482
self.data = torch.randn(4)
484
def __getitem__(self, item):
485
return self.data[item]
487
def __getitems__(self, items):
488
return self.data[items]
493
t = GetItemsDataset()
496
source = StackDataset(t, l)
497
batch = source.__getitems__([0, 1, 2, 3])
499
self.assertEqual(t[i], batch[i][0])
500
self.assertEqual(l[i], batch[i][1])
502
source = StackDataset(t=t, l=l)
503
batch = source.__getitems__([0, 1, 2, 3])
505
self.assertEqual(t[i], batch[i]["t"])
506
self.assertEqual(l[i], batch[i]["l"])
508
def test_getitems_raises_index_error(self):
509
class GetItemsDataset(Dataset):
510
def __init__(self) -> None:
511
self.data = torch.randn(4)
513
def __getitem__(self, item):
514
return self.data[item]
516
def __getitems__(self, items):
517
return self.data[items]
522
t = GetItemsDataset()
525
source = StackDataset(t, l)
527
with self.assertRaises(IndexError):
528
source.__getitems__([0, 4])
530
def test_getitems_value_error(self):
531
class GetItemsDataset(Dataset):
532
def __init__(self) -> None:
533
self.data = torch.randn(4)
535
def __getitem__(self, item):
536
return self.data[item]
538
def __getitems__(self, items):
539
return self.data[items][:-1]
544
t = GetItemsDataset()
547
source = StackDataset(t, l)
549
with self.assertRaisesRegex(
550
ValueError, "Nested dataset's output size mismatch. Expected 4, got 3"
552
source.__getitems__([0, 1, 2, 3])
557
"Fails with TSAN with the following error: starting new threads after multi-threaded "
558
"fork is not supported. Dying (set die_after_fork=0 to override)",
560
class TestConcatDataset(TestCase):
561
def test_concat_two_singletons(self):
    """Concatenating two one-element datasets yields a two-element dataset."""
    combined = ConcatDataset([[0], [1]])
    self.assertEqual(2, len(combined))
    for index, expected in enumerate((0, 1)):
        self.assertEqual(expected, combined[index])
567
def test_concat_two_non_singletons(self):
    """Indexing across the boundary of two concatenated datasets works."""
    first, second = [0, 1, 2, 3, 4], [5, 6, 7, 8, 9]
    combined = ConcatDataset([first, second])
    self.assertEqual(10, len(combined))
    # index 0 lands in the first dataset, index 5 in the second
    self.assertEqual(0, combined[0])
    self.assertEqual(5, combined[5])
573
def test_concat_two_non_singletons_with_empty(self):
    """An empty dataset in the middle must not disturb global indexing."""
    combined = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
    self.assertEqual(10, len(combined))
    self.assertEqual(0, combined[0])
    self.assertEqual(5, combined[5])
580
def test_concat_raises_index_error(self):
581
result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
582
with self.assertRaises(IndexError):
586
def test_add_dataset(self):
    """Dataset.__add__ chains datasets together, preserving order."""
    parts = [
        TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) for _ in range(3)
    ]
    d1, d2, d3 = parts
    summed = d1 + d2 + d3
    self.assertEqual(21, len(summed))
    # Each source dataset's first sample appears at the expected offset.
    for offset, part in zip((0, 7, 14), parts):
        self.assertEqual(0, (part[0][0] - summed[offset][0]).abs().sum())
596
def test_iterable_dataset_err(self):
597
d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
598
it1 = CountingIterableDataset(5)
599
it2 = CountingIterableDataset(10)
601
with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
602
ConcatDataset([d1, it2, it1])
604
with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
607
with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
608
ConcatDataset([it1, d1])
612
def set_faulthander_if_available(_=None):
613
faulthandler.enable(sys.__stderr__)
617
faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False)
620
set_faulthander_if_available()
624
def print_traces_of_all_threads(pid):
627
os.kill(pid, signal.SIGUSR1)
631
os.kill(pid, signal.SIGSEGV)
640
class ErrorTrackingProcess(mp.Process):
645
def __init__(self, disable_stderr=True, **kwargs):
    """Process wrapper that ships any child exception back to the parent.

    disable_stderr: when True, the child silences its stderr (useful for
    tests that intentionally crash workers).
    """
    super().__init__(**kwargs)
    parent_conn, child_conn = mp.Pipe()
    self._pconn = parent_conn
    self._cconn = child_conn
    self._exception = None
    self.disable_stderr = disable_stderr
652
set_faulthander_if_available()
653
if self.disable_stderr:
655
with open(os.devnull, "w") as devnull:
656
os.dup2(devnull.fileno(), sys.stderr.fileno())
659
self._cconn.send(None)
661
self._cconn.send(ExceptionWrapper(sys.exc_info()))
664
def print_traces_of_all_threads(self):
667
), "can only use print_traces_of_all_threads if the process is alive"
669
not self.disable_stderr
670
), "do not disable stderr if you use print_traces_of_all_threads"
675
print_traces_of_all_threads(self.pid)
679
if self._pconn.poll():
680
self._exception = self._pconn.recv()
681
if self._exception is None:
684
return self._exception.exc_type(self._exception.exc_msg)
687
def send_signal(self, signum, ignore_ESRCH=False):
689
os.kill(self.pid, signum)
691
if not ignore_ESRCH or e.errno != errno.ESRCH:
695
class ErrorDataset(Dataset):
696
def __init__(self, size):
703
class SegfaultDataset(Dataset):
704
def __init__(self, size):
707
def __getitem__(self, idx):
    # Read from a NULL pointer to deliberately crash (SIGSEGV) the worker
    # that fetches this item.
    return ctypes.string_at(0)
714
class SleepDataset(Dataset):
715
def __init__(self, size, sleep_sec):
717
self.sleep_sec = sleep_sec
720
def __getitem__(self, idx):
722
time.sleep(self.sleep_sec)
730
class SeedDataset(Dataset):
731
def __init__(self, size):
734
def __getitem__(self, idx):
    # Every index maps to the fetching worker's RNG seed rather than data,
    # letting tests observe per-worker seeding.
    return torch.initial_seed()
741
class WorkerSpecificIterableDataset(IterableDataset):
742
def __init__(self, sizes_for_all_workers):
743
self.sizes_for_all_workers = sizes_for_all_workers
746
worker_info = torch.utils.data.get_worker_info()
747
assert worker_info is not None
748
return iter(range(self.sizes_for_all_workers[worker_info.id]))
751
return sum(self.sizes_for_all_workers)
758
class SynchronizedDataset(Dataset):
759
def __init__(self, size, batch_size, num_workers):
760
assert size >= num_workers * batch_size
761
self.count = mp.Value("i", 0, lock=True)
762
self.barrier = mp.Semaphore(0)
763
self.num_workers = num_workers
767
with self.count.get_lock():
768
self.count.value += 1
769
if self.count.value == self.num_workers:
770
self.barrier.release()
771
self.barrier.acquire()
772
self.barrier.release()
774
def __getitem__(self, idx):
    # Abstract: the base class only provides worker synchronization;
    # concrete subclasses decide what a sample looks like.
    raise NotImplementedError
781
class EmptyTensorDataset(torch.utils.data.Dataset):
782
def __init__(self, len):
788
def __getitem__(self, index):
    """Return a fresh zero-element tensor for every index."""
    # Parameter renamed from ``any`` (which shadowed the builtin); the
    # dunder is only ever invoked positionally, so callers are unaffected.
    return torch.empty(0)
792
class SynchronizedSeedDataset(SynchronizedDataset):
793
def __getitem__(self, idx):
795
return torch.initial_seed()
798
def _test_timeout(persistent_workers):
799
dataset = SleepDataset(10, 3)
800
dataloader = DataLoader(
805
persistent_workers=persistent_workers,
807
_ = next(iter(dataloader))
810
def _test_timeout_pin_memory(persistent_workers):
811
dataset = SleepDataset(10, 3)
812
dataloader = DataLoader(
818
persistent_workers=persistent_workers,
820
_ = next(iter(dataloader))
823
def _test_large_sampler_indices(persistent_workers):
828
dataloader = torch.utils.data.DataLoader(
829
EmptyTensorDataset(10000000),
831
persistent_workers=persistent_workers,
835
it = iter(dataloader)
838
assert x.numel() == 0
839
raise RuntimeError("My Error")
842
def disable_stderr(worker_id):
844
Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
845
from workers. Since worker signal handler prints with low-level write(),
846
this has to be done on OS level via dup.
848
This is used as worker_init_fn for test_segfault.
853
with open(os.devnull, "w") as devnull:
854
os.dup2(devnull.fileno(), sys.stderr.fileno())
858
dataset = SegfaultDataset(10)
859
dataloader = DataLoader(
860
dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr
862
_ = next(iter(dataloader))
865
def _test_no_segfault():
867
num_threads = torch.get_num_threads()
869
torch.set_num_threads(4)
871
torch.set_num_threads(num_threads)
872
mp_ctx = torch.multiprocessing.get_context(method="fork")
873
dataloader = DataLoader(
876
worker_init_fn=disable_stderr,
877
multiprocessing_context=mp_ctx,
879
_ = next(iter(dataloader))
882
class TestProperExitDataset(Dataset):
883
def __init__(self, size, error_event):
885
self.error_event = error_event
890
def __getitem__(self, idx):
891
worker_info = torch.utils.data.get_worker_info()
893
self.error_event is not None
894
and self.error_event.is_set()
895
and worker_info.id == worker_info.num_workers - 1
898
raise RuntimeError("Worker error")
899
return torch.tensor([idx])
902
class TestProperExitIterableDataset(IterableDataset):
903
def __init__(self, size, error_event):
904
self.error_event = error_event
906
self.remaining = size
915
worker_info = torch.utils.data.get_worker_info()
917
self.error_event is not None
918
and self.error_event.is_set()
919
and worker_info.id == worker_info.num_workers - 1
922
raise RuntimeError("Worker error")
924
if self.remaining < 0:
926
return torch.tensor(-1000)
930
def _test_proper_exit(
940
num_workers = 2 if use_workers else 0
942
if exit_method == "worker_error" or exit_method == "worker_kill":
943
assert use_workers is True
945
if exit_method == "worker_error":
946
worker_error_event = mp.Event()
948
worker_error_event = None
950
if is_iterable_dataset:
951
ds = TestProperExitIterableDataset(7, worker_error_event)
953
ds = TestProperExitDataset(12, worker_error_event)
959
num_workers=num_workers,
960
pin_memory=pin_memory,
961
worker_init_fn=set_faulthander_if_available,
962
persistent_workers=persistent_workers,
970
if is_iterable_dataset:
971
assert len(ds) * num_workers > (error_it + 2 + 1)
973
assert len(loader) > (error_it + 2 + 1) * num_workers
975
if is_iterable_dataset:
976
assert len(ds) > error_it + 1
978
assert len(loader) > error_it + 1
982
workers = it._workers
985
psutil_p = psutil.Process(pid)
987
psutil_p.wait(JOIN_TIMEOUT)
988
assert not psutil_p.is_running()
990
for i, _ in enumerate(it):
992
if not hold_iter_reference:
995
loader_setup_event.set()
996
tester_setup_event.wait()
1001
if worker_error_event is not None:
1002
worker_error_event.set()
1005
if exit_method == "loader_error":
1006
raise RuntimeError("Loader error")
1007
elif exit_method == "loader_kill":
1008
kill_pid(os.getpid())
1009
elif exit_method == "worker_kill":
1010
kill_pid(workers[-1].pid)
1012
if not hold_iter_reference:
1020
class TestWorkerInfoDataset(SynchronizedDataset):
1021
def __getitem__(self, idx):
1023
return torch.tensor(self.value)
1028
def _test_worker_info_init_fn(worker_id):
1029
worker_info = torch.utils.data.get_worker_info()
1031
worker_id == worker_info.id
1032
), "worker_init_fn and worker_info should have consistent id"
1034
worker_id < worker_info.num_workers
1035
), "worker_init_fn and worker_info should have valid id"
1037
worker_info.seed == torch.initial_seed()
1038
), "worker_init_fn and worker_info should have consistent seed"
1039
dataset = worker_info.dataset
1041
dataset, TestWorkerInfoDataset
1042
), "worker_info should have correct dataset copy"
1043
assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy"
1046
worker_info.id = 3999
1047
except RuntimeError as e:
1048
assert str(e) == "Cannot assign attributes to WorkerInfo objects"
1051
except RuntimeError as e:
1052
assert str(e) == "Cannot assign attributes to WorkerInfo objects"
1053
for k in ["id", "num_workers", "seed", "dataset"]:
1054
assert f"{k}=" in repr(worker_info)
1055
dataset.value = [worker_id, os.getpid()]
1058
def _test_get_worker_info():
1060
assert torch.utils.data.get_worker_info() is None
1063
dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
1064
dataloader = DataLoader(
1066
batch_size=batch_size,
1067
num_workers=num_workers,
1068
worker_init_fn=_test_worker_info_init_fn,
1070
it = iter(dataloader)
1074
worker_pids = [w.pid for w in it._workers]
1075
data = torch.cat(data, 0)
1079
assert d[1] == worker_pids[d[0]]
1081
assert torch.utils.data.get_worker_info() is None
1083
assert not hasattr(dataset, "value")
1086
except AttributeError:
1088
raise RuntimeError("Expected AttributeError")
1092
def init_fn(worker_id):
    """worker_init_fn that pins every worker to the same fixed seed."""
    torch.manual_seed(12345)
1097
class ErrorIterableDataset(IterableDataset):
1099
raise RuntimeError("Error in __iter__")
1103
def error_worker_init_fn(_):
    """worker_init_fn that always fails, to test error propagation."""
    raise RuntimeError("Error in worker_init_fn")
1107
class BulkLoadingDataset(Dataset):
1108
def __init__(self, length):
    # Number of items addressable through the bulk __getitem__.
    self.length = length
1111
def __getitem__(self, indices):
    # Bulk-loading protocol: a whole batch of indices arrives at once,
    # and the "samples" are simply the indices themselves as a tensor.
    assert isinstance(indices, (list, tuple))
    batch = torch.as_tensor(indices)
    return batch
1119
class BulkLoadingSampler(torch.utils.data.Sampler):
1120
def __init__(self, dataset, batch_size):
    # Sampler that yields whole batches of indices at once
    # (paired with BulkLoadingDataset above).
    self.batch_size = batch_size
    self.dataset = dataset
1125
for x in torch.randperm(len(self.dataset)).split(self.batch_size):
1129
return int(math.ceil(len(self.dataset) / float(self.batch_size)))
1132
class TestMultiEpochDataset(IterableDataset):
1133
def __init__(self, length):
    # Total number of samples to emit per epoch, split across workers.
    self.length = length
1137
worker_info = torch.utils.data.get_worker_info()
1138
assert worker_info is not None
1139
worker_id = worker_info.id
1140
for idx in range(self.length // worker_info.num_workers):
1147
class CustomList(list):
1151
class CustomDict(dict):
1155
def row_processor(row):
    """Element-wise increment of a row (via the NumPy ``add`` ufunc)."""
    return np.add(row, 1)
1160
return len(row) == 4
1165
"Fails with TSAN with the following error: starting new threads after multi-threaded "
1166
"fork is not supported. Dying (set die_after_fork=0 to override)",
1170
"DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
1172
class TestDataLoader(TestCase):
1175
self.data = torch.randn(100, 2, 3, 5)
1176
self.labels = torch.randperm(50).repeat(2)
1177
self.dataset = TensorDataset(self.data, self.labels)
1178
self.persistent_workers = False
1180
def _get_data_loader(self, dataset, **kwargs):
    """Build a DataLoader, defaulting persistent_workers from the test class.

    persistent_workers is forced off when there are no worker processes,
    since the option is only meaningful for multi-process loading.
    """
    use_persistent = kwargs.get("persistent_workers", self.persistent_workers)
    if use_persistent and kwargs.get("num_workers", 0) == 0:
        use_persistent = False
    kwargs["persistent_workers"] = use_persistent
    return DataLoader(dataset, **kwargs)
1187
def _test_sequential(self, loader):
1188
batch_size = loader.batch_size
1189
if batch_size is None:
1190
for idx, (sample, target) in enumerate(loader):
1191
self.assertEqual(sample, self.data[idx])
1192
self.assertEqual(target, self.labels[idx])
1193
self.assertEqual(idx, len(self.dataset) - 1)
1195
for i, (sample, target) in enumerate(loader):
1196
idx = i * batch_size
1197
self.assertEqual(sample, self.data[idx : idx + batch_size])
1198
self.assertEqual(target, self.labels[idx : idx + batch_size])
1199
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1201
def _test_shuffle(self, loader):
1202
found_data = dict.fromkeys(range(self.data.size(0)), 0)
1203
found_labels = dict.fromkeys(range(self.labels.size(0)), 0)
1204
batch_size = loader.batch_size
1205
if batch_size is None:
1206
for i, (batch_samples, batch_targets) in enumerate(loader):
1207
sample, target = (batch_samples, batch_targets)
1208
for data_point_idx, data_point in enumerate(self.data):
1209
if data_point.eq(sample).all():
1210
self.assertFalse(found_data[data_point_idx])
1211
found_data[data_point_idx] += 1
1213
self.assertEqual(target, self.labels[data_point_idx])
1214
found_labels[data_point_idx] += 1
1215
self.assertEqual(sum(found_data.values()), (i + 1))
1216
self.assertEqual(sum(found_labels.values()), (i + 1))
1217
self.assertEqual(i, (len(self.dataset) - 1))
1219
for i, (batch_samples, batch_targets) in enumerate(loader):
1220
for sample, target in zip(batch_samples, batch_targets):
1221
for data_point_idx, data_point in enumerate(self.data):
1222
if data_point.eq(sample).all():
1223
self.assertFalse(found_data[data_point_idx])
1224
found_data[data_point_idx] += 1
1226
self.assertEqual(target, self.labels[data_point_idx])
1227
found_labels[data_point_idx] += 1
1228
self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
1229
self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
1230
self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
1232
def _test_error(self, loader):
1238
except NotImplementedError:
1240
except StopIteration:
1242
errors, math.ceil(float(len(loader.dataset)) / loader.batch_size)
1246
def test_error_in_init(self):
1247
for num_workers in [0, 2]:
1248
loader = self._get_data_loader(
1249
ErrorIterableDataset(), num_workers=num_workers
1251
with self.assertRaisesRegex(RuntimeError, "Error in __iter__"):
1254
loader = self._get_data_loader(
1255
self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn
1257
with self.assertRaisesRegex(RuntimeError, "Error in worker_init_fn"):
1260
def test_typing(self):
1261
from typing import List
1265
class SomeDatasetClass(Dataset[List[torch.Tensor]]):
1268
def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
1271
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
1272
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
1273
def test_fd_limit_exceeded(self):
1277
subprocess.check_output(
1284
from torch.utils.data import DataLoader, IterableDataset
1286
class RandomDataset(IterableDataset):
1287
def __init__(self, len, size):
1288
super(RandomDataset).__init__()
1299
return torch.randn(self.size)
1303
resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
1304
for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
1307
keep_fds_alive.append(random_t)
1308
except RuntimeError as e:
1309
assert "ulimit -n" in str(e)
1310
assert "set_sharing_strategy" in str(e)
1315
def test_invalid_assign_after_init(self):
1316
dl = self._get_data_loader(self.dataset)
1317
for attr in ("batch_size", "sampler", "batch_sampler", "drop_last", "dataset"):
1320
setattr(dl, attr, {})
1322
self.assertRaises(ValueError, fn)
1324
def test_sequential_nonbatch(self):
    # batch_size=None disables automatic batching entirely.
    loader = self._get_data_loader(self.dataset, batch_size=None)
    self._test_sequential(loader)
1327
def test_sequential_batch(self):
    # Exercise both the default batch size and an explicit one.
    for loader_kwargs in ({}, {"batch_size": 2}):
        self._test_sequential(self._get_data_loader(self.dataset, **loader_kwargs))
1331
def test_bulk_loading_nobatch(self):
1334
ds = BulkLoadingDataset(n)
1335
sampler = BulkLoadingSampler(ds, batch_size=4)
1337
for num_workers in [0, 4]:
1338
dl = self._get_data_loader(
1340
num_workers=num_workers,
1343
pin_memory=TEST_CUDA,
1345
self.assertFalse(dl._auto_collation)
1347
self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
1348
self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n)))
1350
def test_growing_dataset(self):
    """len(DataLoader) must reflect dataset growth after construction."""
    items = [torch.ones(4) for _ in range(4)]
    sequential_loader = self._get_data_loader(items, shuffle=False)
    shuffled_loader = self._get_data_loader(items, shuffle=True)
    # Grow the underlying list after the loaders were created.
    items.append(torch.ones(4))
    for loader in (sequential_loader, shuffled_loader):
        self.assertEqual(len(loader), 5)
1358
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_sequential_pin_memory(self):
    """pin_memory=True must hand back page-locked host tensors."""
    loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
    for sample, label in loader:
        self.assertTrue(sample.is_pinned())
        self.assertTrue(label.is_pinned())
1365
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1366
def test_multiple_dataloaders(self):
1367
for multiprocessing_context in supported_multiprocessing_contexts:
1368
loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1))
1370
self._get_data_loader(
1373
multiprocessing_context=multiprocessing_context,
1385
def test_segfault(self):
1386
p = ErrorTrackingProcess(target=_test_segfault)
1388
p.join(JOIN_TIMEOUT)
1390
self.assertFalse(p.is_alive())
1391
self.assertNotEqual(p.exitcode, 0)
1393
self.assertIsInstance(p.exception, OSError)
1394
self.assertRegex(str(p.exception), r"access violation reading ")
1396
self.assertIsInstance(p.exception, RuntimeError)
1399
r"DataLoader worker \(pid \d+\) is killed by signal: ",
1409
@unittest.skipIf(IS_WINDOWS, "Needs fork")
1410
def test_no_segfault(self):
1411
p = ErrorTrackingProcess(target=_test_no_segfault)
1413
p.join(JOIN_TIMEOUT)
1415
self.assertFalse(p.is_alive())
1417
self.assertIsInstance(p.exception, RuntimeError)
1420
r"DataLoader worker \(pid \d+\) is killed by signal: ",
1422
self.fail("Segfault occurred in worker process after fork")
1426
def test_timeout(self):
1427
if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
1431
targets = (_test_timeout, _test_timeout_pin_memory)
1433
targets = (_test_timeout,)
1434
for target in targets:
1435
p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,))
1437
p.join(JOIN_TIMEOUT)
1439
self.assertFalse(p.is_alive())
1440
self.assertNotEqual(p.exitcode, 0)
1441
self.assertIsInstance(p.exception, RuntimeError)
1443
str(p.exception), r"DataLoader timed out after \d+ seconds"
1448
def test_large_sampler_indices(self):
1455
p = ErrorTrackingProcess(
1456
target=_test_large_sampler_indices, args=(self.persistent_workers,)
1459
p.join(JOIN_TIMEOUT)
1461
self.assertFalse(p.is_alive())
1462
self.assertNotEqual(p.exitcode, 0)
1463
self.assertIsInstance(p.exception, RuntimeError)
1464
self.assertRegex(str(p.exception), r"My Error")
1468
def test_invalid_ctor_args_combinations(self):
1470
with self.assertRaisesRegex(
1471
ValueError, "num_workers option should be non-negative"
1473
self._get_data_loader(self.dataset, num_workers=-1)
1474
with self.assertRaisesRegex(
1475
ValueError, "timeout option should be non-negative"
1477
self._get_data_loader(self.dataset, timeout=-1)
1480
with self.assertRaisesRegex(
1482
"batch_size=None option disables auto-batching and is mutually exclusive",
1484
self._get_data_loader(self.dataset, batch_size=None, drop_last=True)
1486
valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
1487
with self.assertRaisesRegex(
1488
ValueError, r"multi-process loading \(num_workers > 0\), but got"
1490
self._get_data_loader(
1491
self.dataset, num_workers=0, multiprocessing_context=valid_ctx
1493
with self.assertRaisesRegex(
1494
ValueError, "should specify a valid start method in"
1496
self._get_data_loader(
1497
self.dataset, num_workers=1, multiprocessing_context="bad"
1499
with self.assertRaisesRegex(
1500
TypeError, "multiprocessing_context option should be a valid context "
1502
self._get_data_loader(
1503
self.dataset, num_workers=1, multiprocessing_context=object()
1507
sampler = torch.utils.data.SequentialSampler(self.dataset)
1508
batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
1509
with self.assertRaisesRegex(
1510
ValueError, "sampler option is mutually exclusive with shuffle"
1512
self._get_data_loader(
1513
self.dataset, batch_size=11, sampler=sampler, shuffle=True
1515
with self.assertRaisesRegex(
1516
ValueError, "sampler option is mutually exclusive with shuffle"
1518
self._get_data_loader(
1519
self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=True
1521
with self.assertRaisesRegex(
1522
ValueError, "sampler option is mutually exclusive with shuffle"
1524
self._get_data_loader(
1525
self.dataset, batch_sampler=batch_sampler, sampler=sampler, shuffle=3
1527
with self.assertRaisesRegex(
1528
ValueError, "batch_sampler option is mutually exclusive with"
1530
self._get_data_loader(
1531
self.dataset, batch_size=11, batch_sampler=batch_sampler
1533
with self.assertRaisesRegex(
1534
ValueError, "batch_sampler option is mutually exclusive with"
1536
self._get_data_loader(
1537
self.dataset, shuffle=True, batch_sampler=batch_sampler
1539
with self.assertRaisesRegex(
1540
ValueError, "batch_sampler option is mutually exclusive with"
1542
self._get_data_loader(
1543
self.dataset, drop_last=True, batch_sampler=batch_sampler
1545
with self.assertRaisesRegex(
1546
ValueError, "batch_sampler option is mutually exclusive with"
1548
self._get_data_loader(
1549
self.dataset, drop_last=3, batch_sampler=batch_sampler
1553
dataset = CountingIterableDataset(20)
1554
with self.assertRaisesRegex(
1555
ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1557
self._get_data_loader(dataset, shuffle=True)
1558
with self.assertRaisesRegex(
1559
ValueError, "DataLoader with IterableDataset: expected unspecified shuffle"
1561
self._get_data_loader(dataset, shuffle=3)
1562
with self.assertRaisesRegex(
1563
ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1565
self._get_data_loader(
1566
dataset, sampler=torch.utils.data.SequentialSampler(dataset)
1568
with self.assertRaisesRegex(
1569
ValueError, "DataLoader with IterableDataset: expected unspecified sampler"
1571
self._get_data_loader(dataset, sampler=3)
1572
with self.assertRaisesRegex(
1574
"DataLoader with IterableDataset: expected unspecified batch_sampler",
1576
self._get_data_loader(
1578
batch_sampler=torch.utils.data.BatchSampler(
1579
torch.utils.data.SequentialSampler(dataset), 3, False
1582
with self.assertRaisesRegex(
1584
"DataLoader with IterableDataset: expected unspecified batch_sampler",
1586
self._get_data_loader(dataset, batch_sampler=3)
1588
def test_builtin_collection_conversion(self):
1589
for coll_ty in (list, tuple):
1590
for num_workers in (0, 1):
1592
dataset = CountingDataset(20)
1595
self._get_data_loader(
1596
dataset, batch_size=None, num_workers=num_workers
1599
self.assertEqual(fetched, coll_ty(range(20)))
1602
self._get_data_loader(
1603
dataset, batch_size=2, num_workers=num_workers
1607
fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1611
dataset = CountingIterableDataset(20)
1614
self._get_data_loader(
1615
dataset, batch_size=None, num_workers=num_workers
1618
self.assertEqual(fetched, coll_ty(range(20)))
1622
assert num_workers in [0, 1], "invalid test"
1624
self._get_data_loader(
1625
dataset, batch_size=2, num_workers=num_workers
1629
fetched, coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
1632
def test_iterable_style_dataset(self):
1634
dataset = CountingIterableDataset(20)
1635
dataloader = self._get_data_loader(dataset, batch_size=None)
1636
fetched = list(dataloader)
1637
self.assertEqual(len(fetched), 20)
1638
for i, d in enumerate(fetched):
1640
self.assertIsInstance(d, int)
1641
self.assertEqual(d, i)
1643
self.assertEqual(len(dataloader), len(dataset))
1647
sizes_for_all_workers = [0, 4, 20]
1650
operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1653
assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1654
for prefetch_factor in [2, 3, 4]:
1655
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1656
dataloader = self._get_data_loader(
1658
num_workers=num_workers,
1660
worker_init_fn=set_faulthander_if_available,
1661
prefetch_factor=prefetch_factor,
1663
dataloader_iter = iter(dataloader)
1664
fetched = sorted(dataloader_iter)
1665
for a, b in zip(fetched, expected):
1667
self.assertIsInstance(a, int)
1668
self.assertEqual(a, b)
1670
self.assertEqual(len(dataloader), len(dataset))
1673
dataset = CountingIterableDataset(20)
1674
dataloader = self._get_data_loader(
1676
num_workers=num_workers,
1677
worker_init_fn=set_faulthander_if_available,
1678
prefetch_factor=prefetch_factor,
1680
it = iter(dataloader)
1683
lambda: next(it), "Should not warn before accessing len(dataloader)"
1685
self.assertEqual(len(dataloader), len(dataset))
1686
self.assertEqual(len(dataloader), 20)
1687
it = iter(dataloader)
1690
lambda: next(it), "Should not warn before exceeding length"
1693
with self.assertWarnsRegex(
1695
r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
1696
msg="Should always warn after exceeding length",
1700
workers = dataloader_iter._workers
1705
w.join(JOIN_TIMEOUT)
1706
self.assertFalse(w.is_alive())
1707
self.assertEqual(w.exitcode, 0)
1713
dataset = CountingIterableDataset(20)
1714
fetched = list(self._get_data_loader(dataset, batch_size=7))
1715
self.assertEqual(len(fetched), 3)
1716
self.assertEqual(fetched[0].tolist(), list(range(7)))
1717
self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1718
self.assertEqual(fetched[2].tolist(), list(range(14, 20)))
1722
sizes_for_all_workers = [0, 4, 20]
1725
operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1728
assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1729
for prefetch_factor in [2, 3, 4]:
1730
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1734
dataloader = self._get_data_loader(
1736
num_workers=num_workers,
1738
prefetch_factor=prefetch_factor,
1740
dataloader_iter = iter(dataloader)
1741
fetched = list(dataloader_iter)
1742
self.assertEqual(len(fetched), 4)
1743
fetched = {tuple(t.tolist()) for t in fetched}
1749
tuple(range(7, 14)),
1750
tuple(range(14, 20)),
1755
workers = dataloader_iter._workers
1760
w.join(JOIN_TIMEOUT)
1761
self.assertFalse(w.is_alive())
1762
self.assertEqual(w.exitcode, 0)
1767
dataset = CountingIterableDataset(20)
1768
fetched = list(self._get_data_loader(dataset, batch_size=7, drop_last=True))
1769
self.assertEqual(len(fetched), 2)
1770
self.assertEqual(fetched[0].tolist(), list(range(7)))
1771
self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
1775
sizes_for_all_workers = [0, 4, 20]
1778
operator.iadd, (list(range(s)) for s in sizes_for_all_workers), []
1781
assert len(sizes_for_all_workers) == num_workers, "invalid test case"
1782
for prefetch_factor in [2, 3, 4]:
1783
dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
1787
dataloader = self._get_data_loader(
1789
num_workers=num_workers,
1792
worker_init_fn=set_faulthander_if_available,
1793
prefetch_factor=prefetch_factor,
1795
dataloader_iter = iter(dataloader)
1796
fetched = list(dataloader_iter)
1797
self.assertEqual(len(fetched), 2)
1798
fetched = {tuple(t.tolist()) for t in fetched}
1799
self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
1802
workers = dataloader_iter._workers
1807
w.join(JOIN_TIMEOUT)
1808
self.assertFalse(w.is_alive())
1809
self.assertEqual(w.exitcode, 0)
1814
def test_chain_iterable_style_dataset(self):
1816
dataset1 = CountingIterableDataset(20)
1817
dataset2 = CountingIterableDataset(15)
1818
expected = list(range(20)) + list(range(15))
1819
for num_workers in [0, 1]:
1820
for chained_dataset in [
1821
dataset1 + dataset2,
1822
ChainDataset([dataset1, dataset2]),
1825
self._get_data_loader(chained_dataset, num_workers=num_workers)
1827
self.assertEqual(len(fetched), len(expected))
1828
for e, d in zip(expected, fetched):
1829
self.assertIsInstance(d, torch.Tensor)
1830
self.assertEqual(e, d)
1832
with self.assertRaisesRegex(
1833
AssertionError, "ChainDataset only supports IterableDataset"
1835
list(iter(dataset1 + self.dataset))
1837
with self.assertRaisesRegex(
1838
AssertionError, "ChainDataset only supports IterableDataset"
1840
list(iter(ChainDataset([dataset1, self.dataset])))
1842
@unittest.skipIf(IS_MACOS, "Not working on macos")
1843
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1845
def test_multiprocessing_contexts(self):
1850
torch.arange(9, 11),
1853
dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA))
1854
for ctx in supported_multiprocessing_contexts:
1857
ctx in ["spawn", "forkserver"]
1862
ds_cls = CUDACountingDataset
1864
ds_cls = CountingDataset
1868
self._get_data_loader(
1869
ds_cls(counting_ds_n),
1870
multiprocessing_context=ctx,
1877
ctx = mp.get_context(ctx)
1881
self._get_data_loader(
1882
ds_cls(counting_ds_n),
1883
multiprocessing_context=ctx,
1889
def _test_multiprocessing_iterdatapipe(self, with_dill):
1894
torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1895
torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64),
1897
datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]])
1898
datapipe = datapipe.map(row_processor)
1900
datapipe.filter(lambda row: len(row) == 4)
1902
else datapipe.filter(filter_len)
1905
dl_common_args = dict(
1906
num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA)
1908
for ctx in supported_multiprocessing_contexts:
1913
for t in self._get_data_loader(
1914
datapipe, multiprocessing_context=ctx, **dl_common_args
1920
ctx = mp.get_context(ctx)
1925
for t in self._get_data_loader(
1926
datapipe, multiprocessing_context=ctx, **dl_common_args
1932
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
def test_multiprocessing_iterdatapipe(self):
    """Run the shared IterDataPipe multiprocessing scenario without dill."""
    self._test_multiprocessing_iterdatapipe(with_dill=False)
1936
@unittest.expectedFailure
1938
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
1940
def test_multiprocessing_iterdatapipe_with_dill(self):
1941
self._test_multiprocessing_iterdatapipe(with_dill=True)
1943
def test_worker_seed(self):
1946
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1947
dataloader = self._get_data_loader(
1948
dataset, batch_size=batch_size, num_workers=num_workers
1951
seeds.update(batch[0] for batch in dataloader)
1952
self.assertEqual(len(seeds), num_workers)
1954
def test_worker_seed_reproducibility(self):
1955
def get_dataloader():
1958
batch_size=batch_size,
1959
num_workers=num_workers,
1960
generator=torch.Generator().manual_seed(42),
1965
dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
1967
{int(batch) for batch in get_dataloader()},
1968
{int(batch) for batch in get_dataloader()},
1971
def test_multi_epochs_reproducibility(self):
1976
dataset = TestMultiEpochDataset(batch_size * num_workers)
1977
dataloader = self._get_data_loader(
1978
dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
1981
for ind in range(num_epochs):
1982
for batch_idx, sample in enumerate(dataloader):
1984
sample.tolist(), [batch_idx % num_workers] * batch_size
1987
def test_worker_init_fn(self):
    """Every fetched element must reflect the value (12345) that the
    per-worker ``init_fn`` hook wrote into the dataset."""
    dataset = SeedDataset(4)
    loader = self._get_data_loader(
        dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn
    )
    for batch in loader:
        # Both entries of each size-2 batch came from workers that ran init_fn.
        for value in batch:
            self.assertEqual(12345, value)
1996
def test_get_worker_info(self):
1997
p = ErrorTrackingProcess(target=_test_get_worker_info)
1999
p.join(JOIN_TIMEOUT)
2001
self.assertFalse(p.is_alive())
2002
self.assertEqual(p.exitcode, 0)
2006
def test_shuffle(self):
    """Shuffled single-process loading still yields every sample."""
    shuffled_loader = self._get_data_loader(self.dataset, shuffle=True)
    self._test_shuffle(shuffled_loader)
2009
def test_shuffle_batch_none(self):
    """Shuffle with auto-batching disabled (batch_size=None)."""
    loader = DataLoader(self.dataset, batch_size=None, shuffle=True)
    self._test_shuffle(loader)
2012
def test_shuffle_batch(self):
2014
self._get_data_loader(self.dataset, batch_size=2, shuffle=True)
2017
def test_shuffle_reproducibility(self):
2023
generator=torch.Generator().manual_seed(42),
2029
generator=torch.Generator().manual_seed(42),
2032
self.assertEqual(list(fn()), list(fn()))
2034
def test_sequential_workers(self):
    """Multi-worker loading must preserve the dataset's sequential order."""
    loader = self._get_data_loader(self.dataset, num_workers=4)
    self._test_sequential(loader)
2037
def test_seqential_batch_workers(self):
    """Batched multi-worker loading preserves sequential order.

    NOTE(review): "seqential" is a long-standing typo for "sequential";
    the name is kept because renaming would change test discovery.
    """
    loader = self._get_data_loader(self.dataset, batch_size=2, num_workers=4)
    self._test_sequential(loader)
2042
def test_seqential_batch_workers_prefetch(self):
    """Same as test_seqential_batch_workers but with a deeper prefetch queue."""
    loader = DataLoader(self.dataset, batch_size=2, num_workers=4, prefetch_factor=3)
    self._test_sequential(loader)
2047
def test_shuffle_workers(self):
2049
self._get_data_loader(self.dataset, shuffle=True, num_workers=4)
2052
def test_shuffle_batch_workers(self):
2054
self._get_data_loader(
2055
self.dataset, batch_size=2, shuffle=True, num_workers=4
2059
def test_shuffle_batch_workers_prefetch(self):
2070
def test_random_sampler(self):
2071
from collections import Counter
2073
from torch.utils.data import RandomSampler
2075
def sample_stat(sampler, num_samples):
2076
counts = Counter(sampler)
2077
count_repeated = sum(val > 1 for val in counts.values())
2082
sum(counts.values()),
2086
n = len(self.dataset) + 1
2087
sampler_with_replacement = RandomSampler(
2088
self.dataset, replacement=True, num_samples=n
2090
count_repeated, minval, maxval, count_total = sample_stat(
2091
sampler_with_replacement, n
2093
self.assertTrue(count_repeated > 0)
2094
self.assertTrue(minval >= 0)
2095
self.assertTrue(maxval < len(self.dataset))
2096
self.assertTrue(count_total == n)
2099
sampler_without_replacement = RandomSampler(self.dataset)
2100
count_repeated, minval, maxval, count_total = sample_stat(
2101
sampler_without_replacement, len(self.dataset)
2103
self.assertTrue(count_repeated == 0)
2104
self.assertTrue(minval == 0)
2105
self.assertTrue(maxval == len(self.dataset) - 1)
2106
self.assertTrue(count_total == len(self.dataset))
2109
n = len(self.dataset) * 2
2110
sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2111
count_repeated, minval, maxval, count_total = sample_stat(
2112
sampler_without_replacement, len(self.dataset)
2114
self.assertTrue(count_repeated == len(self.dataset))
2115
self.assertTrue(minval == 0)
2116
self.assertTrue(maxval == len(self.dataset) - 1)
2117
self.assertTrue(count_total == n)
2119
n = len(self.dataset) - 1
2120
sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2121
count_repeated, minval, maxval, count_total = sample_stat(
2122
sampler_without_replacement, len(self.dataset)
2124
self.assertTrue(count_repeated == 0)
2125
self.assertTrue(minval >= 0)
2126
self.assertTrue(maxval < len(self.dataset))
2127
self.assertTrue(count_total == n)
2129
n = len(self.dataset) + 1
2130
sampler_without_replacement = RandomSampler(self.dataset, num_samples=n)
2131
count_repeated, minval, maxval, count_total = sample_stat(
2132
sampler_without_replacement, len(self.dataset)
2134
self.assertTrue(count_repeated == 1)
2135
self.assertTrue(minval == 0)
2136
self.assertTrue(maxval == len(self.dataset) - 1)
2137
self.assertTrue(count_total == n)
2140
with self.assertRaisesRegex(
2141
TypeError, "replacement should be a boolean value, but got replacement=0"
2143
RandomSampler(self.dataset, replacement=0)
2145
def test_random_sampler_len_with_replacement(self):
2146
from torch.utils.data import RandomSampler
2149
num_samples = len(self.dataset) + 5
2150
sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples)
2152
self.assertEqual(num_samples, len(sampler))
2155
count_num_samples = sum(1 for _ in sampler)
2156
self.assertEqual(num_samples, count_num_samples)
2160
count_num_samples_in_data_loader = len(
2161
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2163
self.assertEqual(num_samples, count_num_samples_in_data_loader)
2167
count_num_samples_in_data_loader = len(
2168
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2171
int(math.ceil(float(num_samples) / batch_size)),
2172
count_num_samples_in_data_loader,
2175
def test_random_sampler_len_without_replacement(self):
2176
from torch.utils.data import RandomSampler
2179
num_samples = len(self.dataset) + 5
2180
sampler = RandomSampler(
2181
self.dataset, replacement=False, num_samples=num_samples
2184
self.assertEqual(num_samples, len(sampler))
2187
count_num_samples = sum(1 for _ in sampler)
2188
self.assertEqual(num_samples, count_num_samples)
2192
count_num_samples_in_data_loader = len(
2193
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2195
self.assertEqual(num_samples, count_num_samples_in_data_loader)
2199
count_num_samples_in_data_loader = len(
2200
self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
2203
num_samples // batch_size + (num_samples % batch_size > 0),
2204
count_num_samples_in_data_loader,
2207
def test_distributed_sampler_invalid_rank(self):
2208
from torch.utils.data.distributed import DistributedSampler
2210
dataset = torch.IntTensor(range(10))
2211
with self.assertRaisesRegex(ValueError, "Invalid rank"):
2212
sampler = DistributedSampler(dataset, 3, 3)
2214
with self.assertRaisesRegex(ValueError, "Invalid rank"):
2215
sampler = DistributedSampler(dataset, 3, -1)
2217
def test_duplicating_data_with_drop_last(self):
2218
from torch.utils.data.distributed import DistributedSampler
2222
data_set = torch.IntTensor(range(num_batches))
2223
scanned_data = torch.IntTensor([])
2224
for i in range(num_processes):
2225
s = DistributedSampler(data_set, num_processes, i)
2226
d_loader = self._get_data_loader(
2228
batch_size=int(num_batches / num_processes),
2232
for data in d_loader:
2233
scanned_data = torch.cat((scanned_data, data), 0)
2235
self.assertEqual(scanned_data.size(), scanned_data.unique().size())
2237
def test_sampler_reproducibility(self):
2238
from torch.utils.data import (
2240
SubsetRandomSampler,
2241
WeightedRandomSampler,
2244
weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
2246
lambda: RandomSampler(
2250
generator=torch.Generator().manual_seed(42),
2252
lambda: RandomSampler(
2255
generator=torch.Generator().manual_seed(42),
2257
lambda: WeightedRandomSampler(
2261
generator=torch.Generator().manual_seed(42),
2263
lambda: WeightedRandomSampler(
2267
generator=torch.Generator().manual_seed(42),
2269
lambda: SubsetRandomSampler(
2270
range(10), generator=torch.Generator().manual_seed(42)
2273
self.assertEqual(list(fn()), list(fn()))
2276
RandomSampler(self.dataset, num_samples=5, replacement=True),
2277
RandomSampler(self.dataset, replacement=False),
2278
WeightedRandomSampler(weights, num_samples=5, replacement=True),
2279
WeightedRandomSampler(weights, num_samples=5, replacement=False),
2280
SubsetRandomSampler(range(10)),
2282
torch.manual_seed(0)
2283
l1 = list(sampler) + list(sampler)
2285
torch.manual_seed(0)
2286
l2 = list(sampler) + list(sampler)
2287
self.assertEqual(l1, l2)
2289
its = (iter(sampler), iter(sampler))
2291
for idx in range(len(sampler)):
2294
torch.manual_seed(0)
2295
ls[i].append(next(its[i]))
2296
self.assertEqual(ls[0], ls[1])
2298
def _test_sampler(self, **kwargs):
    """Check DataLoader with a plain iterable as ``sampler``.

    Indices 2..11 with batch_size=2 give five batches, each of which must
    match the corresponding contiguous slice of ``self.data``.
    """
    index_iterable = range(2, 12)  # any plain iterable works as a sampler
    dl = self._get_data_loader(
        self.dataset, sampler=index_iterable, batch_size=2, **kwargs
    )
    self.assertEqual(len(dl), 5)
    for batch_idx, (batch_input, _target) in enumerate(dl):
        self.assertEqual(len(batch_input), 2)
        lo = batch_idx * 2 + 2
        self.assertEqual(batch_input, self.data[lo : lo + 2])
2308
def test_sampler(self):
    """Exercise the iterable-sampler path serially, with workers, and —
    where spawn is available — under the spawn start method."""
    self._test_sampler()
    self._test_sampler(num_workers=4)
    if not NO_MULTIPROCESSING_SPAWN:
        # Fixed: this previously called self._test_batch_sampler (copy-paste
        # slip from test_batch_sampler), so the sampler path was never
        # exercised under the spawn context.
        self._test_sampler(num_workers=4, multiprocessing_context="spawn")
2314
def _test_batch_sampler(self, **kwargs):
2317
for i in range(0, 20, 5):
2318
batches.append(tuple(range(i, i + 2)))
2319
batches.append(tuple(range(i + 2, i + 5)))
2321
dl = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs)
2322
self.assertEqual(len(dl), 8)
2323
for i, (input, _target) in enumerate(dl):
2326
self.assertEqual(len(input), 2)
2327
self.assertEqual(input, self.data[offset : offset + 2])
2330
self.assertEqual(len(input), 3)
2331
self.assertEqual(input, self.data[offset : offset + 3])
2333
def test_batch_sampler(self):
    """Run the batch_sampler scenario serially, with workers, and under spawn."""
    for extra_kwargs in ({}, {"num_workers": 4}):
        self._test_batch_sampler(**extra_kwargs)
    if not NO_MULTIPROCESSING_SPAWN:
        self._test_batch_sampler(num_workers=4, multiprocessing_context="spawn")
2339
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_shuffle_pin_memory(self):
    """pin_memory=True must hand back pinned tensors even when shuffling
    across multiple workers."""
    loader = self._get_data_loader(
        self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
    )
    for sample, label in loader:
        self.assertTrue(sample.is_pinned())
        self.assertTrue(label.is_pinned())
2348
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2349
def test_numpy(self):
2352
class TestDataset(torch.utils.data.Dataset):
2353
def __getitem__(self, i):
2354
return np.ones((2, 3, 4)) * i
2359
loader = self._get_data_loader(TestDataset(), batch_size=12)
2360
batch = next(iter(loader))
2361
self.assertIsInstance(batch, torch.DoubleTensor)
2362
self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
2364
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2365
def test_numpy_gen_state(self):
2366
from torch.utils.data._utils.worker import _generate_state
2373
(4, 13434589827475259383),
2374
(2884386318, 1088094898, 3523808998, 3860348662),
2376
((1, 15014285634777110771), (1934848465, 763213760, 2959016433, 179751970)),
2378
(10, 978296274032934101),
2379
(1759791917, 3550927336, 1225977135, 1036538043),
2382
(12, 11868770762134256968),
2383
(3974661794, 3331131333, 3630387033, 2885815368),
2386
(9, 15378787925219019706),
2387
(3815056996, 3162224466, 2735102421, 3190253477),
2389
((5, 9055612723125076328), (3522565701, 3368424109, 959377806, 621878693)),
2391
(15, 14617792358407278405),
2392
(3402479508, 1588702753, 1169536393, 3675067356),
2395
(9, 17363320784006640087),
2396
(957989458, 2518334477, 1421725660, 3086155459),
2399
(12, 480002904169484764),
2400
(2732851467, 1762620729, 4055801988, 1277640511),
2403
(15, 16803975943592702950),
2404
(3479415043, 4022359553, 295994005, 3358606349),
2407
(9, 11704776406047813044),
2408
(1968928009, 710113752, 2442656196, 1587420279),
2411
(10, 16357891985431864516),
2412
(1271733898, 4197047399, 3727213786, 2338547348),
2415
(2, 17423369006318065007),
2416
(544294336, 1911284083, 3299147734, 3231058347),
2418
((2, 2889492011444113593), (3721591783, 2595811276, 2212881745, 977682627)),
2419
((0, 8979703111668486195), (4276723937, 2556068849, 2962827292, 233130238)),
2421
(6, 6269787272229682235),
2422
(2548857855, 1216457374, 1012973562, 2999759647),
2426
for (worker_id, base_seed), exp in test_cases:
2427
self.assertEqual(exp, _generate_state(base_seed, worker_id))
2429
def test_error(self):
2431
self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True)
2434
def test_error_workers(self):
2436
self._get_data_loader(
2437
ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4
2441
@unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
2442
def test_partial_workers(self):
2443
r"""Check that workers exit even if the iterator is not exhausted."""
2445
pin_memory_configs = (True, False)
2447
pin_memory_configs = (False,)
2449
for pin_memory in pin_memory_configs:
2451
self._get_data_loader(
2452
self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory
2455
workers = loader._workers
2457
pin_memory_thread = loader._pin_memory_thread
2458
for i, _ in enumerate(loader):
2464
w.join(JOIN_TIMEOUT)
2465
self.assertFalse(w.is_alive(), "subprocess not terminated")
2467
pin_memory_thread.join(JOIN_TIMEOUT)
2468
self.assertFalse(pin_memory_thread.is_alive())
2472
@unittest.skipIf(not HAS_PSUTIL, "psutil not found")
2474
def test_proper_exit(self):
2476
r"""There might be ConnectionResetError or leaked semaphore warning """
2477
r"""(due to dirty process exit), but they are all safe to ignore"""
2484
is_iterable_dataset,
2487
hold_iter_reference,
2488
) in itertools.product([True, False], repeat=4):
2496
if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS):
2511
exit_methods = [None, "loader_error", "worker_error", "worker_kill"]
2512
persistent_workers = self.persistent_workers
2514
exit_methods = [None, "loader_error", "loader_kill"]
2515
persistent_workers = False
2517
for exit_method in exit_methods:
2518
if exit_method == "worker_kill":
2523
desc.append(f"is_iterable_dataset={is_iterable_dataset}")
2524
desc.append(f"use_workers={use_workers}")
2525
desc.append(f"pin_memory={pin_memory}")
2526
desc.append(f"hold_iter_reference={hold_iter_reference}")
2527
desc.append(f"exit_method={exit_method}")
2528
desc = "test_proper_exit with " + ", ".join(desc)
2533
loader_setup_event = mp.Event()
2538
tester_setup_event = mp.Event()
2540
loader_p = ErrorTrackingProcess(
2541
target=_test_proper_exit,
2543
is_iterable_dataset,
2547
hold_iter_reference,
2552
disable_stderr=False,
2555
loader_psutil_p = psutil.Process(loader_p.pid)
2559
loader_setup_event.wait(timeout=JOIN_TIMEOUT)
2560
if not loader_setup_event.is_set():
2562
desc + ": loader process failed to setup within given time"
2564
if loader_p.exception is not None:
2565
fail_msg += f", and had exception {loader_p.exception}"
2566
elif not loader_p.is_alive():
2567
fail_msg += f", and exited with code {loader_p.exitcode} but had no exception"
2569
fail_msg += ", and is still alive."
2570
if loader_p.is_alive():
2572
loader_p.print_traces_of_all_threads()
2576
worker_psutil_ps = loader_psutil_p.children()
2579
report_psutil_attrs = [
2595
err_msg = f"{desc}: {reason}"
2596
err_msg += "\nLoader info:\n\t"
2597
if loader_psutil_p.is_running():
2599
loader_psutil_p.as_dict(attrs=report_psutil_attrs)
2602
loader_p.print_traces_of_all_threads()
2604
err_msg += f"exited with code {loader_p.exitcode}"
2606
err_msg += "\nWorker(s) info:"
2607
for idx, worker_psutil_p in enumerate(worker_psutil_ps):
2608
err_msg += f"\n\tWorker {idx}:\n\t\t"
2609
if worker_psutil_p.is_running():
2611
worker_psutil_p.as_dict(attrs=report_psutil_attrs)
2614
print_traces_of_all_threads(worker_psutil_p.pid)
2616
err_msg += "exited with unknown code"
2619
tester_setup_event.set()
2622
loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
2623
if loader_p.is_alive():
2624
fail_reason = "loader process did not terminate"
2625
if loader_p.exception is not None:
2628
+ f", and had exception {loader_p.exception}"
2631
fail(fail_reason + ", and had no exception")
2632
_, alive = psutil.wait_procs(
2634
timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT),
2638
"worker process (pid(s) {}) did not terminate".format(
2639
", ".join(str(p.pid) for p in alive)
2642
if exit_method is None:
2643
if loader_p.exitcode != 0:
2645
f"loader process had nonzero exitcode {loader_p.exitcode}"
2648
if loader_p.exitcode == 0:
2649
fail("loader process had zero exitcode")
2650
if exit_method == "loader_error":
2652
loader_p.exception, RuntimeError
2653
) or "Loader error" not in str(loader_p.exception):
2655
f"loader process did not raise expected exception, but had {loader_p.exception}"
2657
elif exit_method == "worker_kill":
2658
if isinstance(loader_p.exception, RuntimeError):
2659
if "DataLoader worker (pid" not in str(
2663
f"loader process did not raise expected exception, but had {loader_p.exception}"
2665
elif isinstance(loader_p.exception, ConnectionRefusedError):
2676
f"loader process did not raise expected exception, but had {loader_p.exception}"
2678
elif exit_method == "worker_error":
2680
loader_p.exception, RuntimeError
2681
) or "Worker error" not in str(loader_p.exception):
2683
f"loader process did not raise expected exception, but had {loader_p.exception}"
2686
loader_p.terminate()
2689
def check_len(dl, expected):
2690
self.assertEqual(len(dl), expected)
2694
self.assertEqual(n, expected)
2696
check_len(self.dataset, 100)
2697
check_len(self._get_data_loader(self.dataset, batch_size=2), 50)
2698
check_len(self._get_data_loader(self.dataset, batch_size=3), 34)
2700
def test_iterabledataset_len(self):
2701
class IterableDataset(torch.utils.data.IterableDataset):
2706
return iter(range(10))
2708
iterable_loader = DataLoader(IterableDataset(), batch_size=1)
2709
self.assertEqual(len(iterable_loader), 10)
2710
iterable_loader = DataLoader(IterableDataset(), batch_size=1, drop_last=True)
2711
self.assertEqual(len(iterable_loader), 10)
2713
iterable_loader = DataLoader(IterableDataset(), batch_size=2)
2714
self.assertEqual(len(iterable_loader), 5)
2715
iterable_loader = DataLoader(IterableDataset(), batch_size=2, drop_last=True)
2716
self.assertEqual(len(iterable_loader), 5)
2718
iterable_loader = DataLoader(IterableDataset(), batch_size=3)
2719
self.assertEqual(len(iterable_loader), 4)
2720
iterable_loader = DataLoader(IterableDataset(), batch_size=3, drop_last=True)
2721
self.assertEqual(len(iterable_loader), 3)
2723
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2724
def test_numpy_scalars(self):
2727
class ScalarDataset(torch.utils.data.Dataset):
2728
def __init__(self, dtype):
2731
def __getitem__(self, i):
2738
np.float64: torch.DoubleTensor,
2739
np.float32: torch.FloatTensor,
2740
np.float16: torch.HalfTensor,
2741
np.int64: torch.LongTensor,
2742
np.int32: torch.IntTensor,
2743
np.int16: torch.ShortTensor,
2744
np.int8: torch.CharTensor,
2745
np.uint8: torch.ByteTensor,
2747
for dt, tt in dtypes.items():
2748
dset = ScalarDataset(dt)
2749
loader = self._get_data_loader(dset, batch_size=2)
2750
batch = next(iter(loader))
2751
self.assertIsInstance(batch, tt)
2753
def test_default_convert_mapping_keep_type(self):
    """default_convert must preserve custom Mapping subclasses."""
    original = CustomDict({"a": 1, "b": 2})
    converted = _utils.collate.default_convert(original)
    self.assertEqual(converted, original)
2759
def test_default_convert_sequence_keep_type(self):
    """default_convert must preserve custom Sequence subclasses."""
    original = CustomList([1, 2, 3])
    converted = _utils.collate.default_convert(original)
    self.assertEqual(converted, original)
2765
def test_default_convert_sequence_dont_keep_type(self):
2767
converted = _utils.collate.default_convert(data)
2769
self.assertEqual(converted, [0, 1])
2771
def test_default_collate_dtype(self):
2773
collated = _utils.collate.default_collate(arr)
2774
self.assertEqual(collated, torch.tensor(arr))
2775
self.assertEqual(collated.dtype, torch.int64)
2777
arr = [1.1, 2.3, -0.9]
2778
collated = _utils.collate.default_collate(arr)
2779
self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))
2782
collated = _utils.collate.default_collate(arr)
2783
self.assertEqual(collated, torch.tensor(arr))
2784
self.assertEqual(collated.dtype, torch.bool)
2787
arr = ["a", "b", "c"]
2788
self.assertEqual(arr, _utils.collate.default_collate(arr))
2790
def test_default_collate_mapping_keep_type(self):
    """default_collate of mapping samples must collate value-wise per key and
    keep the mapping's concrete type (CustomDict), not fall back to dict.

    NOTE(review): ``CustomDict`` is declared elsewhere in this file.
    """
    batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
    collated = _utils.collate.default_collate(batch)

    expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
    self.assertEqual(collated, expected)
2797
def test_default_collate_sequence_keep_type(self):
2798
batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
2799
collated = _utils.collate.default_collate(batch)
2801
expected = CustomList(
2803
torch.tensor([1, 4]),
2804
torch.tensor([2, 5]),
2805
torch.tensor([3, 6]),
2808
self.assertEqual(collated, expected)
2810
def test_default_collate_sequence_dont_keep_type(self):
2811
batch = [range(2), range(2)]
2812
collated = _utils.collate.default_collate(batch)
2814
self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])
2816
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2817
def test_default_collate_bad_numpy_types(self):
2821
arr = np.array(["a", "b", "c"])
2822
self.assertEqual(arr, _utils.collate.default_collate(arr))
2824
arr = np.array([[["a", "b", "c"]]])
2825
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2827
arr = np.array([object(), object(), object()])
2828
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2830
arr = np.array([[[object(), object(), object()]]])
2831
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
2833
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2834
def test_default_collate_numpy_memmap(self):
2837
with tempfile.TemporaryFile() as f:
2838
arr = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
2839
arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
2840
arr_memmap[:] = arr[:]
2841
arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
2842
tensor = _utils.collate.default_collate(list(arr_new))
2845
(tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item()
2848
def test_default_collate_bad_sequence_type(self):
2849
batch = [["X"], ["X", "X"]]
2850
self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
2852
RuntimeError, lambda: _utils.collate.default_collate(batch[::-1])
2855
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
2856
def test_default_collate_shared_tensor(self):
2859
t_in = torch.zeros(1)
2862
self.assertEqual(t_in.is_shared(), False)
2864
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
2865
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)
2870
old = _utils.worker._worker_info
2872
_utils.worker._worker_info = "x"
2873
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
2874
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
2876
_utils.worker._worker_info = old
2878
def test_excessive_thread_creation_warning(self):
2879
with self.assertWarnsRegex(
2881
r"excessive worker creation might get DataLoader running slow or even freeze",
2883
dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
2886
class TestDataLoaderDeviceType(TestCase):
2889
[ctx for ctx in supported_multiprocessing_contexts if ctx is not None],
2891
@unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available")
2892
def test_nested_tensor_multiprocessing(self, device, context):
2894
if "cuda" in device and context == "fork":
2899
torch.nested.nested_tensor([torch.randn(5)], device=device)
2903
pin_memory_settings = [False]
2904
if device == "cpu" and torch.cuda.is_available():
2905
pin_memory_settings.append(True)
2907
for pin_memory in pin_memory_settings:
2908
loader = torch.utils.data.DataLoader(
2912
collate_fn=_clone_collate,
2913
pin_memory=pin_memory,
2914
multiprocessing_context=context,
2917
for i, batch in enumerate(loader):
2918
self.assertEqual(batch[0], dataset[i])
2922
with self.assertRaisesRegex(
2923
RuntimeError, "not currently supported by the default collate_fn"
2925
loader = torch.utils.data.DataLoader(
2929
multiprocessing_context=context,
2935
class IntegrationTestDataLoaderDataPipe(TestCase):
2937
Verify the behavior of a certain ``DataPipes`` with ``DataLoader``
2940
def test_shuffler_iterdatapipe(self):
2942
Verify ``IterDataPipe.shuffle`` is controlled by ``DataLoader``
2943
to generate different seeds deterministically per epoch.
2945
exp = list(range(100))
2947
def _create_dp(buffer_size):
2948
input_ds = dp.iter.IterableWrapper(exp)
2949
return input_ds.shuffle(buffer_size=buffer_size).sharding_filter()
2951
for bs in (5, 20, 33):
2953
for num_workers, pw in itertools.product((0, 1, 2), (True, False)):
2954
if num_workers == 0 and pw:
2957
shuffle_dp = _create_dp(bs)
2959
mp_ctx = "spawn" if num_workers > 0 else None
2962
num_workers=num_workers,
2964
multiprocessing_context=mp_ctx,
2965
persistent_workers=pw,
2969
dl_res_ns = list(dl)
2970
self.assertEqual(sorted(dl_res_ns), exp)
2974
for epoch in range(2):
2975
torch.manual_seed(123)
2976
dl_res.append(list(dl))
2977
self.assertEqual(dl_res[0], dl_res[1])
2978
self.assertEqual(sorted(dl_res[0]), exp)
2981
torch.manual_seed(321)
2982
dl_res.append(list(dl))
2984
self.assertEqual(len(dl_res[0]), len(dl_res[2]))
2985
self.assertNotEqual(dl_res[0], dl_res[2])
2986
self.assertEqual(sorted(dl_res[0]), sorted(dl_res[2]))
2988
if dl._iterator is not None:
2989
dl._iterator._shutdown_workers()
2994
class StringDataset(Dataset):
2995
def __init__(self) -> None:
3001
def __getitem__(self, ndx):
3002
return (self.s[ndx], ndx)
3007
"Fails with TSAN with the following error: starting new threads after multi-threaded "
3008
"fork is not supported. Dying (set die_after_fork=0 to override)",
3010
class TestStringDataLoader(TestCase):
3013
self.dataset = StringDataset()
3015
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3016
def test_shuffle_pin_memory(self):
3017
loader = DataLoader(
3018
self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True
3021
self.assertIsInstance(s[0], str)
3022
self.assertTrue(n.is_pinned())
3025
class DictDataset(Dataset):
3029
def __getitem__(self, ndx):
3031
"a_tensor": torch.empty(4, 2).fill_(ndx),
3032
"another_dict": {"a_number": ndx},
3038
"Fails with TSAN with the following error: starting new threads after multi-threaded "
3039
"fork is not supported. Dying (set die_after_fork=0 to override)",
3041
class TestDictDataLoader(TestCase):
3044
self.dataset = DictDataset()
3046
def test_sequential_batch(self):
3047
for persistent_workers in (False, True):
3048
if persistent_workers:
3049
loader = DataLoader(
3053
persistent_workers=persistent_workers,
3057
loader = DataLoader(
3061
persistent_workers=persistent_workers,
3063
batch_size = loader.batch_size
3064
for i, sample in enumerate(loader):
3065
idx = i * batch_size
3066
self.assertEqual(set(sample.keys()), {"a_tensor", "another_dict"})
3067
self.assertEqual(set(sample["another_dict"].keys()), {"a_number"})
3069
t = sample["a_tensor"]
3070
self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
3071
self.assertTrue((t[0] == idx).all())
3072
self.assertTrue((t[1] == idx + 1).all())
3074
n = sample["another_dict"]["a_number"]
3075
self.assertEqual(n.size(), torch.Size([batch_size]))
3076
self.assertEqual(n[0], idx)
3077
self.assertEqual(n[1], idx + 1)
3079
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_pin_memory(self):
    """With pin_memory=True, every tensor in a (nested) dict batch must be
    pinned, including tensors inside nested mappings."""
    loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
    for sample in loader:
        self.assertTrue(sample["a_tensor"].is_pinned())
        self.assertTrue(sample["another_dict"]["a_number"].is_pinned())
3086
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_pin_memory_device(self):
    """pin_memory=True with an explicit pin_memory_device="cuda" must pin
    every tensor in the dict batch for that device.

    NOTE(review): the closing ``)`` of the DataLoader call was lost in the
    extraction; restored here as the only syntactically valid completion.
    """
    loader = DataLoader(
        self.dataset, batch_size=2, pin_memory=True, pin_memory_device="cuda"
    )
    for sample in loader:
        self.assertTrue(sample["a_tensor"].is_pinned(device="cuda"))
        self.assertTrue(sample["another_dict"]["a_number"].is_pinned(device="cuda"))
3095
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_pin_memory_with_only_device(self):
    """pin_memory_device alone (without pin_memory=True) must NOT pin:
    the device argument is only honored when pinning is enabled.

    NOTE(review): the ``self.assertFalse(`` / ``)`` wrapper around the second
    check was lost in the extraction; restored to match the parallel first
    assertion.
    """
    loader = DataLoader(self.dataset, batch_size=2, pin_memory_device="cuda")
    for sample in loader:
        self.assertFalse(sample["a_tensor"].is_pinned(device="cuda"))
        self.assertFalse(
            sample["another_dict"]["a_number"].is_pinned(device="cuda")
        )
3105
class DummyDataset(torch.utils.data.Dataset):
3106
def __init__(self) -> None:
3107
self.data = list(range(10))
3110
return len(self.data)
3112
def __getitem__(self, idx):
3113
if torch.is_tensor(idx):
3119
assert self.start == 0
3120
return self.data[idx]
3125
"Fails with TSAN with the following error: starting new threads after multi-threaded "
3126
"fork is not supported. Dying (set die_after_fork=0 to override)",
3130
"DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
3132
class TestDataLoaderPersistentWorkers(TestDataLoader):
3135
self.persistent_workers = True
3137
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
3138
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
3139
def test_fd_limit_exceeded(self):
3143
subprocess.check_output(
3150
from torch.utils.data import DataLoader, IterableDataset
3152
class RandomDataset(IterableDataset):
3153
def __init__(self, len, size):
3154
super(RandomDataset).__init__()
3165
return torch.randn(self.size)
3169
resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
3170
for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
3171
num_workers=1, persistent_workers=True):
3173
keep_fds_alive.append(random_t)
3174
except RuntimeError as e:
3175
assert "ulimit -n" in str(e)
3176
assert "set_sharing_strategy" in str(e)
3181
def test_dataset_not_reset(self):
3182
dataset = DummyDataset()
3183
pin_memory_configs = [False]
3185
pin_memory_configs.append(True)
3186
for pin_memory in pin_memory_configs:
3187
dataloader = self._get_data_loader(
3188
dataset, num_workers=2, pin_memory=pin_memory
3192
for x in dataloader:
3199
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
3200
@unittest.skipIf(IS_WINDOWS, "Needs fork")
3201
def test_early_exit(self):
3204
proc = subprocess.check_output(
3210
from torch.utils.data import DataLoader, IterableDataset
3212
class RandomDataset(IterableDataset):
3213
def __init__(self, len, size):
3214
super(RandomDataset).__init__()
3225
return torch.randn(self.size)
3227
if __name__ == '__main__':
3229
RandomDataset(64, (28, 28)),
3233
persistent_workers=True,
3234
multiprocessing_context="fork",
3244
class NamedTupleDataset(Dataset):
3245
from collections import namedtuple
3247
Batch = namedtuple("Batch", ["data", "label", "random_tensor"])
3248
Data = namedtuple("Data", ["positive", "negative"])
3253
def __getitem__(self, ndx):
3255
data=self.Data(positive=ndx, negative=-ndx),
3257
random_tensor=torch.randn(3),
3263
"Fails with TSAN with the following error: starting new threads after multi-threaded "
3264
"fork is not supported. Dying (set die_after_fork=0 to override)",
3266
class TestNamedTupleDataLoader(TestCase):
3269
self.dataset = NamedTupleDataset()
3271
def test_dataloader_with_namedtuple(self):
    """Batched (batch_size=2) and unbatched (batch_size=None) loading must
    both preserve the namedtuple structure of samples; pinning (checked only
    when TEST_CUDA) applies to tensor fields."""
    # Auto-collation path: fields are collated, so .positive becomes a tensor.
    loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA)
    for batch in loader:
        self.assertIsInstance(batch, NamedTupleDataset.Batch)
        self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
        self.assertIsInstance(batch.data, NamedTupleDataset.Data)
        self.assertIsInstance(batch.data.positive, torch.Tensor)
        self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA)
    # batch_size=None disables auto-collation: samples pass through as-is,
    # so .positive stays a plain int rather than being turned into a tensor.
    loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA)
    for batch in loader:
        self.assertIsInstance(batch, NamedTupleDataset.Batch)
        self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA)
        self.assertIsInstance(batch.data, NamedTupleDataset.Data)
        self.assertNotIsInstance(batch.data.positive, torch.Tensor)
3289
class SimpleCustomBatch:
    """Custom batch type exercising DataLoader's custom pin-memory protocol:
    DataLoader calls ``pin_memory()`` on batches that define it and expects
    the pinned batch back."""

    def __init__(self, data):
        # data is a list of (input, target) pairs; stack each column into one
        # batch tensor along dim 0.
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        # The pin-memory protocol requires returning the pinned batch; the
        # tests below (see test_custom_batch_pin) call is_pinned() on what the
        # loader yields, which would be None without this return.
        return self

    def is_pinned(self):
        return self.inp.is_pinned() and self.tgt.is_pinned()
3307
# Import this very module by file name so worker processes can pickle
# collate_wrapper and resolve SimpleCustomBatch through the module reference.
self_module = __import__(os.path.splitext(os.path.basename(__file__))[0])


def collate_wrapper(batch):
    """Collate a list of (input, target) pairs into a SimpleCustomBatch."""
    return self_module.SimpleCustomBatch(batch)
3314
def collate_into_packed_sequence(batch):
3315
data = torch.stack([sample[0] for sample in batch], 1)
3317
lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
3318
return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
3321
def collate_into_packed_sequence_batch_first(batch):
3322
data = torch.stack([sample[0] for sample in batch], 0)
3324
lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
3325
return torch.nn.utils.rnn.pack_padded_sequence(
3326
data, lengths, batch_first=True, enforce_sorted=False
3332
"Fails with TSAN with the following error: starting new threads after multi-threaded "
3333
"fork is not supported. Dying (set die_after_fork=0 to override)",
3335
class TestCustomPinFn(TestCase):
3338
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
3339
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
3340
self.dataset = TensorDataset(inps, tgts)
3342
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3343
def test_custom_batch_pin(self):
3345
(collate_wrapper, self_module.SimpleCustomBatch),
3346
(collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
3348
collate_into_packed_sequence_batch_first,
3349
torch.nn.utils.rnn.PackedSequence,
3352
for collate_fn, elem_cls in test_cases:
3353
loader = DataLoader(
3354
self.dataset, batch_size=2, collate_fn=collate_fn, pin_memory=True
3356
for sample in loader:
3357
self.assertIsInstance(sample, elem_cls)
3358
self.assertTrue(sample.is_pinned())
3360
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
3361
def test_custom_batch_pin_worker(self):
3363
(collate_wrapper, self_module.SimpleCustomBatch),
3364
(collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
3366
collate_into_packed_sequence_batch_first,
3367
torch.nn.utils.rnn.PackedSequence,
3370
for collate_fn, elem_cls in test_cases:
3371
loader = DataLoader(
3374
collate_fn=collate_fn,
3378
for sample in loader:
3379
self.assertIsInstance(sample, elem_cls)
3380
self.assertTrue(sample.is_pinned())
3383
class TestWorkerQueueDataset(Dataset):
3384
def __init__(self, data):
3386
self.worker_id = None
3388
def worker_init_fn(self, worker_id):
    """Record the id of the worker that owns this dataset copy so the tests
    can attribute each sample to the worker that produced it."""
    self.worker_id = worker_id
3391
def __getitem__(self, item):
    """Return (worker_id, value) so tests can check which worker served
    which index; worker_id is None until worker_init_fn has run."""
    return self.worker_id, self.data[item]
3395
return len(self.data)
3400
"Fails with TSAN with the following error: starting new threads after multi-threaded "
3401
"fork is not supported. Dying (set die_after_fork=0 to override)",
3405
"Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727",
3407
class TestIndividualWorkerQueue(TestCase):
3410
self.dataset = TestWorkerQueueDataset(list(range(128)))
3412
def _run_ind_worker_queue_test(self, batch_size, num_workers):
3413
loader = DataLoader(
3415
batch_size=batch_size,
3417
num_workers=num_workers,
3419
worker_init_fn=self.dataset.worker_init_fn,
3421
current_worker_idx = 0
3422
for i, (worker_ids, sample) in enumerate(loader):
3423
self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
3425
sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size))
3427
current_worker_idx += 1
3428
if current_worker_idx == num_workers:
3429
current_worker_idx = 0
3431
def test_ind_worker_queue(self):
3432
max_num_workers = None
3433
if hasattr(os, "sched_getaffinity"):
3435
max_num_workers = len(os.sched_getaffinity(0))
3438
if max_num_workers is None:
3439
cpu_count = os.cpu_count()
3440
if cpu_count is not None:
3442
max_num_workers = cpu_count // 2
3444
if max_num_workers is None:
3447
for batch_size in (8, 16, 32, 64):
3448
for num_workers in range(0, min(6, max_num_workers)):
3449
self._run_ind_worker_queue_test(
3450
batch_size=batch_size, num_workers=num_workers + 1
3454
class SetAffinityDataset(IterableDataset):
3457
after = os.sched_getaffinity(0)
3462
not hasattr(os, "sched_setaffinity"), "os.sched_setaffinity is not available"
3464
class TestSetAffinity(TestCase):
3465
def test_set_affinity_in_worker_init(self):
    """A worker_init_fn that calls os.sched_setaffinity must take effect in
    the worker process: every yielded sample reports exactly the affinity the
    init fn set.

    NOTE(review): the closing ``)`` of the DataLoader call was lost in the
    extraction; restored as the only syntactically valid completion.
    """
    # Query the current mask so we only pin to a CPU we are allowed to use.
    old_affinity = os.sched_getaffinity(0)
    if not old_affinity:
        self.skipTest("No affinity information")
    # Pick an arbitrary allowed CPU to pin the workers to.
    expected_affinity = list(old_affinity)[-1]

    def worker_set_affinity(_):
        os.sched_setaffinity(0, [expected_affinity])

    dataset = SetAffinityDataset()

    dataloader = torch.utils.data.DataLoader(
        dataset, num_workers=2, worker_init_fn=worker_set_affinity
    )
    for sample in dataloader:
        self.assertEqual(sample, [expected_affinity])
3485
class ConvDataset(Dataset):
3486
def __init__(self) -> None:
3487
self.x = torch.ones(1, 1, 24000)
3494
def __getitem__(self, index):
3495
return torch.nn.functional.conv1d(self.x, torch.ones(1, 1, 2))
3498
@unittest.skipIf(IS_WINDOWS, "Needs fork")
3501
"This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492",
3503
class TestConvAfterFork(TestCase):
3505
def test_conv_after_fork(self):
3506
loader = DataLoader(ConvDataset(), num_workers=1)
3508
self.assertEqual(x.shape, (1, 1, 1, 23999))
3511
# Generate per-device (e.g. CPU/CUDA) concrete test classes from the generic
# TestDataLoaderDeviceType template and register them in this module's globals.
instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
3514
if __name__ == "__main__":