pytorch

Форк
0
/
queue_util.py 
136 строк · 4.4 Кб
1
## @package queue_util
2
# Module caffe2.python.queue_util
3

4

5

6

7

8
from caffe2.python import core, dataio
9
from caffe2.python.task import TaskGroup
10

11
import logging
12

13

14
logger = logging.getLogger(__name__)
15

16

17
class _QueueReader(dataio.Reader):
18
    def __init__(self, wrapper, num_dequeue_records=1):
19
        assert wrapper.schema is not None, (
20
            'Queue needs a schema in order to be read from.')
21
        dataio.Reader.__init__(self, wrapper.schema())
22
        self._wrapper = wrapper
23
        self._num_dequeue_records = num_dequeue_records
24

25
    def setup_ex(self, init_net, exit_net):
26
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
27

28
    def read_ex(self, local_init_net, local_finish_net):
29
        self._wrapper._new_reader(local_init_net)
30
        dequeue_net = core.Net('dequeue')
31
        fields, status_blob = dequeue(
32
            dequeue_net,
33
            self._wrapper.queue(),
34
            len(self.schema().field_names()),
35
            field_names=self.schema().field_names(),
36
            num_records=self._num_dequeue_records)
37
        return [dequeue_net], status_blob, fields
38

39
    def read(self, net):
40
        net, _, fields = self.read_ex(net, None)
41
        return net, fields
42

43

44
class _QueueWriter(dataio.Writer):
45
    def __init__(self, wrapper):
46
        self._wrapper = wrapper
47

48
    def setup_ex(self, init_net, exit_net):
49
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
50

51
    def write_ex(self, fields, local_init_net, local_finish_net, status):
52
        self._wrapper._new_writer(self.schema(), local_init_net)
53
        enqueue_net = core.Net('enqueue')
54
        enqueue(enqueue_net, self._wrapper.queue(), fields, status)
55
        return [enqueue_net]
56

57

58
class QueueWrapper(dataio.Pipe):
59
    def __init__(self, handler, schema=None, num_dequeue_records=1):
60
        dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
61
        self._queue = handler
62
        self._num_dequeue_records = num_dequeue_records
63

64
    def reader(self):
65
        return _QueueReader(
66
            self, num_dequeue_records=self._num_dequeue_records)
67

68
    def writer(self):
69
        return _QueueWriter(self)
70

71
    def queue(self):
72
        return self._queue
73

74

75
class Queue(QueueWrapper):
76
    def __init__(self, capacity, schema=None, name='queue',
77
                 num_dequeue_records=1):
78
        # find a unique blob name for the queue
79
        net = core.Net(name)
80
        queue_blob = net.AddExternalInput(net.NextName('handler'))
81
        QueueWrapper.__init__(
82
            self, queue_blob, schema, num_dequeue_records=num_dequeue_records)
83
        self.capacity = capacity
84
        self._setup_done = False
85

86
    def setup(self, global_init_net):
87
        assert self._schema, 'This queue does not have a schema.'
88
        self._setup_done = True
89
        global_init_net.CreateBlobsQueue(
90
            [],
91
            [self._queue],
92
            capacity=self.capacity,
93
            num_blobs=len(self._schema.field_names()),
94
            field_names=self._schema.field_names())
95

96

97
def enqueue(net, queue, data_blobs, status=None):
98
    if status is None:
99
        status = net.NextName('status')
100
    # Enqueueing moved the data into the queue;
101
    # duplication will result in data corruption
102
    queue_blobs = []
103
    for blob in data_blobs:
104
        if blob not in queue_blobs:
105
            queue_blobs.append(blob)
106
        else:
107
            logger.warning("Need to copy blob {} to enqueue".format(blob))
108
            queue_blobs.append(net.Copy(blob))
109
    results = net.SafeEnqueueBlobs([queue] + queue_blobs, queue_blobs + [status])
110
    return results[-1]
111

112

113
def dequeue(net, queue, num_blobs, status=None, field_names=None,
114
            num_records=1):
115
    if field_names is not None:
116
        assert len(field_names) == num_blobs
117
        data_names = [net.NextName(name) for name in field_names]
118
    else:
119
        data_names = [net.NextName('data', i) for i in range(num_blobs)]
120
    if status is None:
121
        status = net.NextName('status')
122
    results = net.SafeDequeueBlobs(
123
        queue, data_names + [status], num_records=num_records)
124
    results = list(results)
125
    status_blob = results.pop(-1)
126
    return results, status_blob
127

128

129
def close_queue(step, *queues):
130
    close_net = core.Net("close_queue_net")
131
    for queue in queues:
132
        close_net.CloseBlobsQueue([queue], 0)
133
    close_step = core.execution_step("%s_step" % str(close_net), close_net)
134
    return core.execution_step(
135
        "%s_wraper_step" % str(close_net),
136
        [step, close_step])
137

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.