pytorch

Форк
0
/
data_workers.py 
461 строка · 15.6 Кб
1
## @package data_workers
2
# Module caffe2.python.data_workers
3

4

5

6

7

8

9
'''
10
This module provides a python-land multithreaded data input mechanism
11
for Caffe2 nets.
12

13
Basic usage is as follows:
14
   coordinator = data_workers.init_data_input_workers(
15
      net,
16
      ["data", "label"],
17
      my_fetch_fun,
18
      batch_size=32,
19
      input_source_name="train",
20
      dont_rebatch=False
21
   )
22
   ...
23
   coordinator.start()
24

25
First argument is the Caffe2 net (or model helper), and second argument
26
is list of input blobs that are to be fed.
27

28
Argument 'input_source_name' is used to distinguish different sources of data,
29
such as train or test data. This is to ensure the data does not get mixed up,
30
even when two nets share blobs.
31

32
To do the actual data loading, one defines a "fetcher function"
33
that has call signature
34
   my_fetch_fun(worker_id, batch_size)
35

36
Optionally, one can define an "init function" that is called once before
37
threads start, and has call signature:
38
   my_init_fun(data_coordinator, global_coordinator)
39

40
If dont_rebatch is set to True, the data input is not batched into equal sized
41
chunks but data directly provided by fetchers is used.
42

43
'batch_columns' can be used to specify which dimension is the batch dimension,
44
for each of the inputs. Default is 0 for all inputs.
45

46
'timeout' is the timeout in seconds after which if no data is available, the
47
net will fail (default 600s = 10 mins).
48

49
The fetcher function returns a list of numpy arrays corresponding to the different
50
input blobs. In the example above, it would return two arrays, one for the
51
data blob and another for the labels. These arrays can have arbitrary number
52
of elements (i.e. they do not need to match the batch size). The batch size
53
is provided for the function as a hint only.
54

55
For example, fetcher function could download images from a remote service or
56
load random images from a directory on a file system.
57

58
For a dummy example, see the data_workers_test unit test.
59

60
Note that for data_parallel_models, init_data_input_workers will be called
61
for each GPU. Note that the 'coordinator' returned by the function is the same
62
each time.
63
'''
64

65
import queue as Queue
66
from itertools import chain
67
import logging
68
import threading
69
import numpy as np
70
import time
71

72
from caffe2.python import workspace, core, scope, utils
73
from caffe2.proto import caffe2_pb2
74
from caffe2.python.parallel_workers import Metrics, State, \
75
    WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
76

77
# Minimum number of seconds between the periodic "lagging behind" warnings
# and the inputs/sec throughput log lines emitted by BatchFeeder.
LOG_INT_SECS = 60

# Module logger; INFO level so the periodic throughput lines are visible.
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
80

81

82
def get_worker_ids(num_workers):
    '''Returns the worker ids 0, 1, ..., num_workers - 1 as a new list.'''
    return [*range(num_workers)]
84

85

86
def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    '''
    Builds the fetcher threads and the enqueuer thread that feed
    `input_blob_names` of `net`, and registers them with the module-level
    `global_coordinator`.

    The threads are only constructed here, not started. The function returns
    `global_coordinator` itself, so every call returns the same object.

    Args:
        net: net (or model helper) that receives the DequeueBlobs ops.
        input_blob_names: blobs to be fed.
        fetch_fun: fetcher called as fetch_fun(worker_id, batch_size).
        batch_size: target batch size (a hint for the fetcher).
        num_worker_threads: number of fetcher threads to create.
        input_source_name: source tag; feeders with the same tag share one
            python-side queue.
        max_buffered_batches: capacity of that queue when it is first created.
        init_fun: optional init function handed to WorkerCoordinator.
        external_loggers: forwarded to Metrics.
        dont_rebatch: if True, fetched chunks are enqueued as-is.
        batch_columns: per-input batch dimension (default 0 for all).
        timeout: DequeueBlobs timeout in seconds.
    '''
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        # No device scope is active: feed on CPU.
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    metrics = Metrics(external_loggers)
    name_scope = scope.CurrentNameScope()
    shared_queue = global_coordinator.get_queue(
        input_source_name, max_buffered_batches)
    batch_feeder = BatchFeeder(
        net, input_blob_names, batch_size, device_option, name_scope,
        input_source_name, shared_queue, metrics, dont_rebatch,
        batch_columns, timeout=timeout,
    )

    # Reserve globally unique ids for the fetcher threads.
    worker_ids = []
    for _ in range(num_worker_threads):
        worker_ids.append(global_coordinator.get_new_worker_id())

    coordinator = WorkerCoordinator(
        input_source_name, worker_ids, init_fun, batch_feeder)

    threads = []
    for worker_id in worker_ids:
        fetcher = DataWorker(coordinator, worker_id, fetch_fun, metrics,
                             batch_size, batch_feeder)
        threads.append(threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator, fetcher],
        ))

    # A single enqueuer thread moves data from the python queue into Caffe2.
    enqueuer_thread_name = "Enqueuer {} {}".format(
        input_source_name, scope.CurrentNameScope())
    threads.append(threading.Thread(
        target=enqueuer,
        name=enqueuer_thread_name,
        args=[coordinator, batch_feeder],
    ))

    coordinator._workers = threads
    global_coordinator.add(coordinator)

    return global_coordinator
148

149

150
class BatchFeeder(State):
    '''
    Moves fetched data from a python-side queue into per-blob Caffe2 queues.

    Fetcher threads call put() to add validated chunks to the shared python
    queue. The enqueuer thread repeatedly calls _enqueue_batch(), which
    drains that queue, optionally re-batches the data to `batch_size`, and
    feeds each input into its Caffe2 queue via the SafeEnqueueBlobs op. The
    net reads the data through the DequeueBlobs ops added by
    _create_caffe2_ops().
    '''

    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 metrics, dont_rebatch, batch_columns, timeout=600):
        '''
        Args:
            net: net to which one DequeueBlobs op per input blob is added.
            input_blob_names: names of the blobs fed by this feeder.
            batch_size: samples per batch (not used when dont_rebatch).
            device_option: device option for scratch blobs and enqueue ops.
            namescope: name-scope prefix for the scratch blob names.
            input_source_name: source tag appended to the Caffe2 queue and
                scratch blob names.
            queue: python-side queue shared with the fetcher threads.
            metrics: Metrics object receiving timing/throughput values.
            dont_rebatch: if True, chunks are enqueued exactly as fetched.
            batch_columns: per-input batch dimension; defaults to 0 for all.
            timeout: DequeueBlobs timeout in seconds.
        '''
        self._counter = 0
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._timeout = timeout
        self._input_source_name = input_source_name
        self._c2_queue_capacity = 4
        self._create_caffe2_queues(net)
        self._create_caffe2_ops(net)
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._dont_rebatch = dont_rebatch
        self._init_scratch()
        self._metrics = metrics

        if batch_columns is None:
            batch_columns = [0 for _ in input_blob_names]
        self._batch_columns = batch_columns

    def start(self):
        '''Resets the throughput counters at the start of a run.'''
        self._inputs = 0
        self._prev_seconds = time.time()

    def stop(self):
        '''
        Closes every Caffe2 queue created by this feeder, then always emits
        a final throughput log, even if closing a queue fails.
        '''
        try:
            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
        finally:
            self._log_inputs_per_interval(0, force=True)

    def cleanup(self):
        '''Resets the scratch data and status blobs used for enqueueing.'''
        utils.ResetBlobs(self._scratch_blob.values())
        utils.ResetBlobs(self._scratch_status.values())

    def _get(self, data_input_coordinator):
        '''
        Blocks until a chunk is available on the python-side queue and
        returns it. Returns None once the coordinator is no longer active.
        Logs a warning roughly every 10 seconds while waiting.
        '''
        start_time = time.time()
        last_warning = time.time()
        while data_input_coordinator.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                if time.time() - last_warning > 10.0:
                    log.warning("** Data input is slow: (still) no data in {} secs.".format(
                        time.time() - start_time))
                    last_warning = time.time()
                continue
        return None

    def _validate_chunk(self, chunk):
        '''
        Checks a chunk returned by a fetcher.

        Returns:
            False (after logging a warning) if the chunk is None or empty,
            True if it may be enqueued.
        Raises:
            AssertionError if the chunk violates the fetcher contract (wrong
            number of inputs, non-ndarray entries, or, when re-batching,
            mismatched sample counts along the batch columns).
        '''
        if chunk is None:
            log.warning("Fetcher function returned None")
            return False

        # Must run before the per-input length assertion below. Otherwise an
        # empty chunk always trips that assertion and this branch would be
        # unreachable.
        if len(chunk) == 0:
            log.warning("Worker provided zero length input")
            return False

        assert len(chunk) == len(self._input_blob_names), \
            "Expecting data blob for each input"
        for d in chunk:
            assert isinstance(d, np.ndarray), \
                "Fetcher function must return a numpy array"
        if not self._dont_rebatch:
            # Every input must hold the same number of samples along its own
            # batch dimension, otherwise re-batching would misalign them.
            for j, d in enumerate(chunk[1:], start=1):
                assert d.shape[self._batch_columns[j]] == \
                    chunk[0].shape[self._batch_columns[0]], \
                    "Each returned input must have equal number of samples"

        return True

    def put(self, chunk, data_input_coordinator):
        '''
        Validates `chunk` and appends it to the python-side queue.

        While the queue is full, retries in 0.5 second slices for as long as
        the coordinator stays active. Returns without enqueuing if the chunk
        is rejected or the coordinator becomes inactive.
        '''
        if not self._validate_chunk(chunk):
            return

        while data_input_coordinator.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    log.warning("Warning, data loading lagging behind: " +
                                "queue size={}, name={}".format(qsize, self._input_source_name))
                    self._last_warning = time.time()
                self._internal_queue.put(chunk, block=True, timeout=0.5)
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")
                continue
            # Count only chunks that were actually enqueued, not retries.
            self._counter += 1
            # Count samples along the configured batch dimension, the same
            # way _enqueue_batch measures batch size.
            self._log_inputs_per_interval(
                chunk[0].shape[self._batch_columns[0]])
            return

    def _enqueue_batch_direct(self, data_input_coordinator):
        '''Feeds one fetched chunk to the Caffe2 queues without re-batching.'''
        data = self._get(data_input_coordinator)
        if data is None:
            return
        if data_input_coordinator.is_active():
            for b, q, c in zip(self._input_blob_names, self._queues, data):
                self._enqueue(b, q, c)

    def _enqueue_batch(self, data_input_coordinator):
        '''
        This pulls data from the python-side queue and collects them
        into batch-sized pieces, unless dont_rebatch is set to true.

        Samples beyond batch_size are pushed back onto the python-side queue.
        This is best effort: they are dropped, with a debug log, if that
        queue is full.
        '''
        if self._dont_rebatch:
            self._enqueue_batch_direct(data_input_coordinator)
            return

        cur_batch = [np.array([]) for _ in self._input_blob_names]
        first_batch_col = self._batch_columns[0]

        # Collect data until we have a full batch size. The shape[0] == 0
        # test comes first so an empty 1-D placeholder is never indexed at
        # a higher batch column.
        while (
            cur_batch[0].shape[0] == 0 or
            cur_batch[0].shape[first_batch_col] < self._batch_size
        ) and data_input_coordinator.is_active():
            chunk = self._get(data_input_coordinator)
            if chunk is None:
                continue

            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(
                        cur_batch[j], chunk_elem, axis=self._batch_columns[j]
                    )

        start_time = time.time()
        try:
            # Return data over the batch size back to queue
            if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
                first_batch_col
            ] > self._batch_size:
                leftover = []
                trimmed_batch = []
                for j, b in enumerate(cur_batch):
                    [head, rest] = np.split(
                        b, [self._batch_size], axis=self._batch_columns[j]
                    )
                    leftover.append(rest)
                    trimmed_batch.append(head)
                cur_batch = trimmed_batch
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    # Deliberately non-blocking: drop the surplus rather than
                    # stall the enqueuer, but leave a trace of it.
                    log.debug(
                        "Queue full: dropping {} leftover samples, name={}".format(
                            leftover[0].shape[first_batch_col],
                            self._input_source_name))

                assert cur_batch[0].shape[first_batch_col] == self._batch_size

            if data_input_coordinator.is_active():
                for b, q, c in zip(
                    self._input_blob_names, self._queues, cur_batch
                ):
                    self._enqueue(b, q, c)
        finally:
            self._metrics.put_metric('enqueue_time', time.time() - start_time)

    def _init_scratch(self):
        '''Creates and pre-feeds the per-input scratch data/status blobs.'''
        self._scratch_blob = {}
        self._scratch_status = {}
        for blob_name in self._input_blob_names:
            scratch_name = self._namescope + blob_name + \
                "_scratch_" + self._input_source_name
            self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
            self._scratch_status[blob_name] = core.BlobReference(
                scratch_name + "_status"
            )

        # Feed empty arrays to the scratch blobs here, so that there won't be
        # race conditions when calling FeedBlob (which calls workspace
        # CreateBlob()) from enqueue threads
        for b in chain(
            self._scratch_blob.values(), self._scratch_status.values()
        ):
            workspace.FeedBlob(
                b,
                np.array([]).astype(np.float32),
                device_option=self._device_option,
            )

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        workspace.FeedBlob(
            self._scratch_blob[blob_name],
            data_arr,
            device_option=self._device_option
        )

        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, self._scratch_blob[blob_name]],
            [self._scratch_blob[blob_name], self._scratch_status[blob_name]],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues(self, net):
        '''
        Creates queues on caffe2 side
        '''
        def create_queue(queue_name, num_blobs, capacity):
            # Honour the caller's num_blobs. It was previously hard-coded
            # to 1, which matches the only call site below.
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(
                qname, num_blobs=1, capacity=self._c2_queue_capacity
            )
            self._queues.append(q)

    def _create_caffe2_ops(self, net):
        '''
        Creates dequeue-ops on caffe2 side
        '''
        for q, blob_name in zip(self._queues, self._input_blob_names):
            # Add operator to the Caffe2 network to dequeue
            net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))

    def _log_inputs_per_interval(self, inputs, force=False):
        '''
        Accumulates `inputs` samples. Once LOG_INT_SECS have elapsed (or when
        `force` is set), logs inputs/sec and queue size, flushes the metrics
        and resets the counters.
        '''
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            # A forced log right after start() can see a zero interval on
            # low-resolution clocks. Avoid a ZeroDivisionError.
            if delta_seconds > 0:
                inputs_per_sec = int(self._inputs / delta_seconds)
            else:
                inputs_per_sec = 0
            qsize = self._internal_queue.qsize()
            log.info("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            log.info("-- queue: {} batches".format(qsize))
            # log and reset perf metrics
            self._metrics.put_metric(
                'inputs_per_sec', inputs_per_sec, False)
            self._metrics.put_metric('queue_size', qsize, False)
            self._metrics.put_metric(
                'time_elapsed', delta_seconds, False)
            self._metrics.log_metrics()
            self._metrics.reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds
410

411

412
class GlobalCoordinator(GlobalWorkerCoordinator):
    '''
    GlobalWorkerCoordinator that also owns one python-side queue per input
    source name, shared by every BatchFeeder created for that source.
    '''

    def __init__(self):
        GlobalWorkerCoordinator.__init__(self)
        # input source name -> Queue.Queue
        self._queues = {}

    def get_queue(self, queue_name, max_buffered_batches):
        '''
        Returns the queue registered under `queue_name`, creating it on first
        use. `max_buffered_batches` only sets the capacity at creation time.
        '''
        assert isinstance(max_buffered_batches, int)
        existing = self._queues.get(queue_name)
        if existing is not None:
            return existing
        created = Queue.Queue(maxsize=max_buffered_batches)
        self._queues[queue_name] = created
        return created

    def reset_data_input(self, namescope, name, net, batch_size):
        '''
        For each registered coordinator whose worker name is `name` and whose
        feeder uses `namescope`: sets the feeder's batch size to `batch_size`
        and adds its DequeueBlobs ops to `net`.
        '''
        log.info("Reset data input {}, batch size {}: ".format(name, batch_size))
        matching = (
            c for c in self._coordinators
            if c._worker_name == name and c._state._namescope == namescope
        )
        for coordinator in matching:
            feeder = coordinator._state
            feeder._batch_size = batch_size
            feeder._create_caffe2_ops(net)
429

430

431
class DataWorker(Worker):
    '''
    Fetcher worker: each run() calls the user fetch function once and hands
    its result to the shared BatchFeeder.
    '''

    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun,
        metrics,
        batch_size,
        batch_feeder
    ):
        Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
                        metrics=metrics)
        # Passed to the fetch function as a size hint only.
        self._batch_size = batch_size
        self._batch_feeder = batch_feeder

    def run(self):
        '''Fetches one chunk and forwards it to the batch feeder.'''
        fetched = self._worker_fun(self._worker_id, self._batch_size)
        feeder = self._batch_feeder
        feeder.put(fetched, self._coordinator)

    def finish(self):
        '''Records elapsed time since self._start_time as 'fetcher_time'.'''
        metrics = self._metrics
        metrics.put_metric('fetcher_time', time.time() - self._start_time)
454

455

456
# Module-level singleton shared by every init_data_input_workers() call. It
# hands out the per-source python queues (get_queue), collects the per-call
# WorkerCoordinators (add), and is the object init_data_input_workers()
# returns.
global_coordinator = GlobalCoordinator()
457

458

459
def enqueuer(coordinator, batch_feeder):
    '''
    Thread body for the enqueuer: keeps moving batches into Caffe2 via
    batch_feeder._enqueue_batch() until the coordinator becomes inactive.
    '''
    enqueue_one_batch = batch_feeder._enqueue_batch
    while coordinator.is_active():
        enqueue_one_batch(coordinator)
462

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.