## @package data_workers
# Module caffe2.python.data_workers

'''
This module provides a python-land multithreaded data input mechanism
for Caffe2 nets.

Basic usage is as follows:
   coordinator = data_workers.init_data_input_workers(
      net,
      ["data", "label"],
      my_fetch_fun,
      batch_size=32,
      input_source_name="train",
   )
   ...
   coordinator.start()

The first argument is the Caffe2 net (or model helper), and the second
argument is the list of input blobs that are to be fed.

The argument 'input_source_name' is used to distinguish different sources of
data, such as train or test data. This ensures the data does not get mixed up,
even when two nets share blobs.
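
For instance, train and test inputs can be set up side by side (the nets and
fetchers named here are illustrative):
   data_workers.init_data_input_workers(
      train_net, ["data", "label"], my_train_fetcher, batch_size=32,
      input_source_name="train")
   data_workers.init_data_input_workers(
      test_net, ["data", "label"], my_test_fetcher, batch_size=32,
      input_source_name="test")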

To do the actual data loading, one defines a "fetcher function"
that has the call signature:
   my_fetch_fun(worker_id, batch_size)

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:
   my_init_fun(data_coordinator, global_coordinator)
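
Such an init function could, for example, set up a resource shared by all
fetcher threads of this input source; a minimal sketch (the helper named
here is hypothetical):
   def my_init_fun(data_coordinator, global_coordinator):
       # runs exactly once, before any fetcher thread starts
       data_coordinator._file_list = list_training_files()  # hypothetical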

If dont_rebatch is set to True, the data input is not batched into equal
sized chunks; the data provided by the fetchers is used directly instead.

'batch_columns' can be used to specify which dimension is the batch dimension,
for each of the inputs. The default is 0 for all inputs.
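
For example, if an input is time-major with shape (T, N, D), where N is the
number of samples, its batch dimension is 1 (this call is illustrative only):
   data_workers.init_data_input_workers(
      net,
      ["seq", "label"],
      my_fetch_fun,
      batch_size=32,
      batch_columns=[1, 0],
   )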

'timeout' is the timeout in seconds after which, if no data is available,
the net will fail (default 600s = 10 mins).

The fetcher function returns a list of numpy arrays corresponding to the
different input blobs. In the example above, it would return two arrays, one
for the data blob and another for the labels. These arrays can have an
arbitrary number of elements (i.e. they do not need to match the batch size).
The batch size is provided to the function as a hint only.

For example, the fetcher function could download images from a remote service
or load random images from a directory on a file system.
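
A minimal sketch of such a fetcher, returning random data (the shapes and
dtypes are illustrative; the batch size is only a hint):
   def my_fetch_fun(worker_id, batch_size):
       data = np.random.rand(batch_size, 3, 32, 32).astype(np.float32)
       labels = np.random.randint(0, 10, size=batch_size).astype(np.int32)
       return [data, labels]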

For a dummy example, see the data_workers_test unit test.

Note that for data_parallel_models, init_data_input_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
'''

try:
    import Queue
except ImportError:
    import queue as Queue
import logging
import threading
import time

import numpy as np

from itertools import chain
from caffe2.python import workspace, core, scope, utils
from caffe2.proto import caffe2_pb2
from caffe2.python.parallel_workers import Metrics, State, \
    WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker

log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def get_worker_ids(num_workers):
    return list(range(0, num_workers))


def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    global global_coordinator
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    metrics = Metrics(external_loggers)
    batch_feeder = BatchFeeder(
        net,
        input_blob_names,
        batch_size,
        device_option,
        scope.CurrentNameScope(),
        input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        metrics,
        dont_rebatch,
        batch_columns,
        timeout=timeout
    )

    # Launch fetch worker threads
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]

    # Create coordinator object
    coordinator = WorkerCoordinator(
        input_source_name, worker_ids, init_fun, batch_feeder)

    workers = [
        threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator,
                  DataWorker(coordinator, worker_id, fetch_fun, metrics,
                             batch_size, batch_feeder)],
        ) for worker_id in worker_ids
    ]

    workers.append(threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(input_source_name, scope.CurrentNameScope()),
        args=[coordinator, batch_feeder]))
    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class BatchFeeder(State):
    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 metrics, dont_rebatch, batch_columns, timeout=600):
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._timeout = timeout
        self._input_source_name = input_source_name
        self._c2_queue_capacity = 4
        self._create_caffe2_queues(net)
        self._create_caffe2_ops(net)
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._dont_rebatch = dont_rebatch
        self._init_scratch()
        self._metrics = metrics

        if batch_columns is None:
            batch_columns = [0 for _ in input_blob_names]
        self._batch_columns = batch_columns

    def start(self):
        self._inputs = 0
        self._prev_seconds = time.time()

    def stop(self):
        try:
            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
        finally:
            self._log_inputs_per_interval(0, force=True)

    def cleanup(self):
        utils.ResetBlobs(self._scratch_blob.values())
        utils.ResetBlobs(self._scratch_status.values())

    def _get(self, data_input_coordinator):
        start_time = time.time()
        last_warning = time.time()
        while data_input_coordinator.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                if time.time() - last_warning > 10.0:
                    log.warning("** Data input is slow: (still) no data in {} secs.".format(
                        time.time() - start_time))
                    last_warning = time.time()
        return None

    def _validate_chunk(self, chunk):
        if chunk is None:
            log.warning("Fetcher function returned None")
            return False

        assert len(chunk) == len(self._input_blob_names), \
            "Expecting data blob for each input"
        for d in chunk:
            assert isinstance(d, np.ndarray), \
                "Fetcher function must return a numpy array"
        if not self._dont_rebatch:
            for j, d in enumerate(chunk):
                assert d.shape[self._batch_columns[j]] == \
                    chunk[0].shape[self._batch_columns[0]], \
                    "Each returned input must have equal number of samples"

        if len(chunk) == 0:
            log.warning("Worker provided zero length input")
            return False

        return True

    def put(self, chunk, data_input_coordinator):
        if not self._validate_chunk(chunk):
            return

        while data_input_coordinator.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    log.warning("Warning, data loading lagging behind: " +
                                "queue size={}, name={}".format(qsize, self._input_source_name))
                    self._last_warning = time.time()
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                self._log_inputs_per_interval(chunk[0].shape[0])
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")

    def _enqueue_batch_direct(self, data_input_coordinator):
        data = self._get(data_input_coordinator)
        if data is None:
            return
        if data_input_coordinator.is_active():
            for b, q, c in zip(self._input_blob_names, self._queues, data):
                self._enqueue(b, q, c)

    def _enqueue_batch(self, data_input_coordinator):
        '''
        This pulls data from the python-side queue and collects them
        into batch-sized pieces, unless dont_rebatch is set to true.
        '''
        if self._dont_rebatch:
            self._enqueue_batch_direct(data_input_coordinator)
            return

        cur_batch = [np.array([]) for d in self._input_blob_names]
        first_batch_col = self._batch_columns[0]

        # Collect data until we have a full batch size
        while (
            cur_batch[0].shape[0] == 0 or
            cur_batch[0].shape[first_batch_col] < self._batch_size
        ) and data_input_coordinator.is_active():
            chunk = self._get(data_input_coordinator)
            if chunk is None:
                continue
            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(
                        cur_batch[j], chunk_elem, axis=self._batch_columns[j]
                    )

        start_time = time.time()
        try:
            # Return data over the batch size back to queue
            if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
                first_batch_col
            ] > self._batch_size:
                leftover = []
                trimmed_batch = []
                for j, b in enumerate(cur_batch):
                    [c, l] = np.split(
                        b, [self._batch_size], axis=self._batch_columns[j]
                    )
                    leftover.append(l)
                    trimmed_batch.append(c)
                cur_batch = trimmed_batch
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    pass

                assert cur_batch[0].shape[first_batch_col] == self._batch_size

            if data_input_coordinator.is_active():
                for b, q, c in zip(
                    self._input_blob_names, self._queues, cur_batch
                ):
                    self._enqueue(b, q, c)
        finally:
            self._metrics.put_metric('enqueue_time', time.time() - start_time)

    def _init_scratch(self):
        self._scratch_blob = {}
        self._scratch_status = {}
        for blob_name in self._input_blob_names:
            scratch_name = self._namescope + blob_name + \
                "_scratch_" + self._input_source_name
            self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
            self._scratch_status[blob_name] = core.BlobReference(
                scratch_name + "_status"
            )

        # Feed empty arrays to the scratch blobs here, so that there won't be
        # race conditions when calling FeedBlob (which calls workspace
        # CreateBlob()) from enqueue threads
        for b in chain(
            self._scratch_blob.values(), self._scratch_status.values()
        ):
            workspace.FeedBlob(
                b,
                np.array([]).astype(np.float32),
                device_option=self._device_option,
            )

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        workspace.FeedBlob(
            self._scratch_blob[blob_name],
            data_arr,
            device_option=self._device_option
        )

        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, self._scratch_blob[blob_name]],
            [self._scratch_blob[blob_name], self._scratch_status[blob_name]],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues(self, net):
        '''
        Creates queues on caffe2 side
        '''
        def create_queue(queue_name, num_blobs, capacity):
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(
                qname, num_blobs=1, capacity=self._c2_queue_capacity
            )
            self._queues.append(q)

    def _create_caffe2_ops(self, net):
        '''
        Creates dequeue-ops on caffe2 side
        '''
        for q, blob_name in zip(self._queues, self._input_blob_names):
            # Add operator to the Caffe2 network to dequeue
            net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))

    def _log_inputs_per_interval(self, inputs, force=False):
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            inputs_per_sec = int(self._inputs / delta_seconds)
            qsize = self._internal_queue.qsize()
            log.info("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            log.info("-- queue: {} batches".format(qsize))
            # log and reset perf metrics
            self._metrics.put_metric(
                'inputs_per_sec', inputs_per_sec, False)
            self._metrics.put_metric('queue_size', qsize, False)
            self._metrics.put_metric(
                'time_elapsed', delta_seconds, False)
            self._metrics.log_metrics()
            self._metrics.reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds


class GlobalCoordinator(GlobalWorkerCoordinator):
    def __init__(self):
        GlobalWorkerCoordinator.__init__(self)
        self._queues = {}

    def get_queue(self, queue_name, max_buffered_batches):
        assert isinstance(max_buffered_batches, int)
        if queue_name not in self._queues:
            self._queues[queue_name] = Queue.Queue(maxsize=max_buffered_batches)
        return self._queues[queue_name]

    def reset_data_input(self, namescope, name, net, batch_size):
        log.info("Reset data input {}, batch size {}".format(name, batch_size))
        for c in self._coordinators:
            if c._worker_name == name and c._state._namescope == namescope:
                c._state._batch_size = batch_size
                c._state._create_caffe2_ops(net)


class DataWorker(Worker):
    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun,
        metrics,
        batch_size,
        batch_feeder
    ):
        Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
                        metrics=metrics)
        self._batch_size = batch_size
        self._batch_feeder = batch_feeder

    def run(self):
        input_data = self._worker_fun(self._worker_id, self._batch_size)
        self._batch_feeder.put(input_data, self._coordinator)

    def finish(self):
        self._metrics.put_metric(
            'fetcher_time', time.time() - self._start_time)


global_coordinator = GlobalCoordinator()


def enqueuer(coordinator, batch_feeder):
    while coordinator.is_active():
        batch_feeder._enqueue_batch(coordinator)