BasicSR

Форк
0
/
prefetch_dataloader.py 
122 строки · 3.1 Кб
1
import queue as Queue
2
import threading
3
import torch
4
from torch.utils.data import DataLoader
5

6

7
class PrefetchGenerator(threading.Thread):
8
    """A general prefetch generator.
9

10
    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
11

12
    Args:
13
        generator: Python generator.
14
        num_prefetch_queue (int): Number of prefetch queue.
15
    """
16

17
    def __init__(self, generator, num_prefetch_queue):
18
        threading.Thread.__init__(self)
19
        self.queue = Queue.Queue(num_prefetch_queue)
20
        self.generator = generator
21
        self.daemon = True
22
        self.start()
23

24
    def run(self):
25
        for item in self.generator:
26
            self.queue.put(item)
27
        self.queue.put(None)
28

29
    def __next__(self):
30
        next_item = self.queue.get()
31
        if next_item is None:
32
            raise StopIteration
33
        return next_item
34

35
    def __iter__(self):
36
        return self
37

38

39
class PrefetchDataLoader(DataLoader):
40
    """Prefetch version of dataloader.
41

42
    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
43

44
    TODO:
45
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
46
    ddp.
47

48
    Args:
49
        num_prefetch_queue (int): Number of prefetch queue.
50
        kwargs (dict): Other arguments for dataloader.
51
    """
52

53
    def __init__(self, num_prefetch_queue, **kwargs):
54
        self.num_prefetch_queue = num_prefetch_queue
55
        super(PrefetchDataLoader, self).__init__(**kwargs)
56

57
    def __iter__(self):
58
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
59

60

61
class CPUPrefetcher():
62
    """CPU prefetcher.
63

64
    Args:
65
        loader: Dataloader.
66
    """
67

68
    def __init__(self, loader):
69
        self.ori_loader = loader
70
        self.loader = iter(loader)
71

72
    def next(self):
73
        try:
74
            return next(self.loader)
75
        except StopIteration:
76
            return None
77

78
    def reset(self):
79
        self.loader = iter(self.ori_loader)
80

81

82
class CUDAPrefetcher():
83
    """CUDA prefetcher.
84

85
    Reference: https://github.com/NVIDIA/apex/issues/304#
86

87
    It may consume more GPU memory.
88

89
    Args:
90
        loader: Dataloader.
91
        opt (dict): Options.
92
    """
93

94
    def __init__(self, loader, opt):
95
        self.ori_loader = loader
96
        self.loader = iter(loader)
97
        self.opt = opt
98
        self.stream = torch.cuda.Stream()
99
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
100
        self.preload()
101

102
    def preload(self):
103
        try:
104
            self.batch = next(self.loader)  # self.batch is a dict
105
        except StopIteration:
106
            self.batch = None
107
            return None
108
        # put tensors to gpu
109
        with torch.cuda.stream(self.stream):
110
            for k, v in self.batch.items():
111
                if torch.is_tensor(v):
112
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
113

114
    def next(self):
115
        torch.cuda.current_stream().wait_stream(self.stream)
116
        batch = self.batch
117
        self.preload()
118
        return batch
119

120
    def reset(self):
121
        self.loader = iter(self.ori_loader)
122
        self.preload()
123

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.