8
from caffe2.python import core, queue_util
9
from caffe2.python.dataio import Reader, Writer
10
from caffe2.python.net_builder import NetBuilder, ops
11
from caffe2.python.schema import as_record, Field
12
from caffe2.python.task import Node, Task, TaskGroup
17
Represents the result of a processor function. A processor can either
18
return an Output, or it can return a record, in which case an Output will be
19
created for it afterwards.
21
def __init__(self, nets=None, record=None, should_stop=None):
22
builder_children = NetBuilder.current().get()
23
assert nets is None or len(builder_children) == 0, (
24
'Cannot both use `ops` syntax and return a list of nets.')
26
nets = builder_children
27
if isinstance(nets, core.Net):
29
self.nets = [] if nets is None else list(nets)
30
self.record = None if record is None else as_record(record)
31
self.should_stop = should_stop
34
DEFAULT_QUEUE_CAPACITY = 100
37
def _init_output(output, capacity, global_init_net, global_exit_net):
39
out_queue = queue_util.Queue(
41
capacity if capacity is not None
42
else DEFAULT_QUEUE_CAPACITY))
43
writer = out_queue.writer()
44
elif isinstance(output, Writer):
45
assert capacity is None, 'capacity would not be used.'
48
elif hasattr(output, 'writer'):
49
assert capacity is None, 'capacity would not be used.'
51
writer = output.writer()
53
raise ValueError('output must be a reader, queue or stream.')
54
writer.setup_ex(global_init_net, global_exit_net)
55
return out_queue, writer
58
def make_processor(processor, reader=None):
60
return lambda rec: rec
61
elif isinstance(processor, core.Net):
62
return NetProcessor(processor)
64
if reader is not None and hasattr(processor, "schema_func"):
65
def processor_schema():
66
return processor.schema_func(reader)
68
processor.schema = processor_schema
72
def normalize_processor_output(output):
74
Allow for processors to return results in several formats.
75
TODO(azzolini): simplify once all processors use NetBuilder API.
77
if isinstance(output, Output):
78
""" Processor returned an Output. """
80
elif isinstance(output, Field):
81
""" Processor returned a record. """
82
return Output(record=output)
83
elif isinstance(output, tuple):
84
is_record_and_blob = (
86
isinstance(output[0], Field) and
87
isinstance(output[1], core.BlobReference))
88
if is_record_and_blob:
89
""" Processor returned (record, stop_blob) """
90
return Output(None, *output)
92
""" Processor returned (nets, record, stop_blob) """
93
return Output(*output)
95
""" Processor returned nets, no output """
100
input, output=None, num_threads=1, processor=None, name=None,
101
capacity=None, group=None, num_runtime_threads=1):
103
Given a Reader, Queue or DataStream in `input`, and optionally, a Writer,
104
Queue or DataStream in `output`, creates a Task that, when run, will
105
pipe the input into the output, using multiple parallel threads.
106
Additionally, if a processor is given, it will be called between reading
107
and writing steps, allowing it to transform the record.
110
input: either a Reader, Queue or DataStream that will be read
111
until a stop is signaled either by the reader or the
113
output: either a Writer, a Queue or a DataStream that will be
114
written to as long as neither reader nor writer signal
115
a stop condition. If output is not provided or is None,
116
a Queue is created with given `capacity` and written to.
117
num_threads: number of concurrent threads used for processing and
118
piping. If set to 0, no Task is created, and a
119
reader is returned instead -- the reader returned will
120
read from the reader passed in and process it.
121
** DEPRECATED **. Use `num_runtime_threads` instead.
122
This option will be removed once all readers/processors
123
support `num_runtime_threads`.
124
processor: (optional) function that takes an input record and
125
optionally returns a record; this will be called
126
between read and write steps. If the processor does
127
not return a record, a writer will not be instantiated.
128
Processor can also be a core.Net with input and output
129
records properly set. In that case, a NetProcessor is
130
instantiated, cloning the net for each of the threads.
131
name: (optional) name of the task to be created.
132
capacity: when output is not passed, a queue of given `capacity`
133
is created and written to.
134
group: (optional) explicitly add the created Task to this
135
TaskGroup, instead of using the currently active one.
136
num_runtime_threads: Similar to `num_threads`, but instead of expanding
137
the tasks with a `for` loop in python, does that at
138
runtime. This is preferable to `num_threads`, but some
139
processors/readers still require to be called multiple
143
Output Queue, DataStream, Reader, or None, depending on the parameters
146
result, _ = _pipe_step(
147
input, output, num_threads, processor, name, capacity, group,
153
input, output=None, num_threads=1, processor=None, name=None,
154
capacity=None, group=None, num_runtime_threads=1, final_outputs=None):
156
Similar to `pipe`, with the additional ability for the pipe Task to
157
return output values to the `Session` once done.
160
Tuple (out_queue, *task_outputs)
161
out_queue: same as return value of `pipe`.
162
task_outputs: TaskOutput object, fetchable from the client after
163
session.run() returns.
165
assert num_threads > 0
166
result, task = _pipe_step(
167
input, output, num_threads, processor, name, capacity, group,
168
num_runtime_threads, final_outputs)
170
if final_outputs is not None:
171
output = task.outputs()
172
if type(final_outputs) not in (list, tuple):
174
return result, output
177
def processor_name(processor):
    """Return a human-readable display name for a pipeline processor.

    Resolution order:
      1. an explicit `name` attribute (e.g. a NetProcessor or core.Net);
      2. the function name -- Python 2 stored it in `func_name`, Python 3
         stores it in `__name__`. Lambdas report their defining module
         instead (their own name, `<lambda>`, is useless), and Python 2
         bound methods are qualified with their `im_class` class name;
      3. otherwise, the processor's class name.
    """
    if hasattr(processor, 'name'):
        return processor.name
    # `func_name` exists only on Python 2 functions; fall back to `__name__`
    # so plain functions are still named correctly under Python 3 instead of
    # falling through and being reported as their class name ("function").
    func_name = getattr(
        processor, 'func_name', None) or getattr(processor, '__name__', None)
    if func_name is not None:
        if func_name == '<lambda>':
            # A lambda's name carries no information; use its module.
            return processor.__module__
        # Python 2 bound methods expose `im_class`; qualify with the class.
        if hasattr(processor, 'im_class'):
            return '%s.%s' % (processor.im_class.__name__, func_name)
        return func_name
    return processor.__class__.__name__
189
def _runtime_threads_task(name, group, final_outputs, reader, num_threads,
191
node_name = str(Node.current())
192
profiler_name = "{0}/{1}/{2}/{3}/{4}".format(
196
processor_name(input) if input else "NoInput",
197
processor_name(output) if output else "NoOutput")
199
with Task(name=name, group=group, outputs=final_outputs,
200
num_instances=num_threads) as task:
201
global_exit_net = core.Net('pipe:exit')
202
global_init_net = core.Net('pipe:init')
203
reader.setup_ex(global_init_net, global_exit_net)
205
init_net = core.Net('pipe:instance:init')
206
exit_net = core.Net('pipe:instance:exit')
207
read_nets, status, rec = reader.read_record_ex(init_net, exit_net)
208
init_net.ConstantFill(
212
dtype=core.DataType.BOOL
216
out_queue, writer = _init_output(
217
output, capacity, global_init_net, global_exit_net)
218
write_nets, _ = writer.write_record_ex(
219
rec, init_net, exit_net, status)
224
with ops.task_init():
225
ops.net(global_init_net)
226
with ops.task_instance_init():
229
timer_start_net = core.Net('timer_start')
230
timer = timer_start_net.TimerBegin([], counter_name=profiler_name)
231
timer_end_net = core.Net('timer_end')
232
timer_end_net.TimerEnd(timer, [])
234
ops.net(core.execution_step(
236
[timer_start_net] + list(read_nets) + list(write_nets) +
238
should_stop_blob=status))
239
ops.net(timer_end_net)
241
with ops.task_instance_exit():
243
with ops.task_exit():
244
ops.net(global_exit_net)
246
return out_queue, task
249
def _static_threads_task(name, group, final_outputs, reader, num_threads,
251
node_name = str(Node.current())
252
profiler_name = "{0}/{1}/{2}/{3}/{4}".format(
256
processor_name(input) if input else "NoInput",
257
processor_name(output) if output else "NoOutput")
259
with Task(name=name, group=group, outputs=final_outputs) as task:
260
global_exit_net = core.Net('exit')
261
global_init_net = core.Net('init')
262
reader.setup_ex(global_init_net, global_exit_net)
268
for thread_id in range(num_threads):
269
with NetBuilder(name='t:%d' % thread_id) as nb:
270
init_net = core.Net('init')
271
exit_net = core.Net('exit')
272
read_nets, status, rec = reader.read_record_ex(
274
init_net.ConstantFill(
278
dtype=core.DataType.BOOL
285
with NetBuilder(_fullname=task.name):
286
out_queue, writer = _init_output(
287
output, capacity, global_init_net,
289
write_nets, _ = writer.write_record_ex(
290
rec, init_net, exit_net, status)
294
timer_start_net = core.Net('timer_start')
295
timer = timer_start_net.TimerBegin([], counter_name=profiler_name)
296
timer_end_net = core.Net('timer_end')
297
timer_end_net.TimerEnd(timer, [])
300
ops.net(core.execution_step(
302
[timer_start_net] + list(read_nets) + list(write_nets) +
304
should_stop_blob=status))
305
ops.net(timer_end_net)
307
steps.append(core.to_execution_step(nb))
308
ops.net(global_init_net)
309
ops.net(core.execution_step('body', steps, concurrent_substeps=True))
310
ops.net(global_exit_net)
311
return out_queue, task
315
input, output=None, num_threads=1, processor=None, name=None,
316
capacity=None, group=None, num_runtime_threads=None, final_outputs=None):
319
assert num_threads <= 1 or num_runtime_threads <= 1, (
320
'Only one of num_threads or num_runtime_threads must be set.')
322
if isinstance(input, Reader):
324
elif hasattr(input, 'reader'):
325
reader = input.reader()
328
'Input must be a reader, queue or stream. Got {}'.format(type(input)))
330
if processor is not None:
331
reader = ProcessingReader(reader, processor)
333
if num_threads == 0 or num_runtime_threads == 0:
334
assert output is None
337
if name is None and processor is not None:
338
name = processor_name(processor)
339
if name is None and output is not None:
340
name = 'pipe_into:%s' % processor_name(output)
342
name = 'pipe_from:%s' % processor_name(input)
345
return _static_threads_task(
346
name, group, final_outputs, reader, num_threads, output, capacity)
348
return _runtime_threads_task(
349
name, group, final_outputs, reader, num_runtime_threads, output,
353
class ProcessingReader(Reader):
355
Reader that reads from an upstream reader, calls the processor, and returns
356
the processed record.
358
def __init__(self, reader, processor):
359
Reader.__init__(self)
361
self.processor = make_processor(processor, reader)
364
return self.processor.schema()
366
def setup_ex(self, init_net, finish_net):
    """Register init/exit logic on the given nets.

    Purely delegates to the wrapped upstream reader.
    """
    upstream_reader = self.reader
    upstream_reader.setup_ex(init_net, finish_net)
369
def read_ex(self, init_net, exit_net):
370
read_nets, status, rec = self.reader.read_record_ex(init_net, exit_net)
375
with NetBuilder() as nb:
381
result = normalize_processor_output(self.processor(rec))
382
read_nets += result.nets
383
if result.should_stop or nb._stop_blob:
384
stop_net = core.Net('stop_net')
385
if result.should_stop:
386
stop_net.Or([status, result.should_stop], [status])
388
stop_net.Or([status, nb._stop_blob], [status])
389
read_nets.append(stop_net)
390
if hasattr(self.processor, 'setup'):
391
init_net.add_attribute(TaskGroup.LOCAL_SETUP, self.processor)
392
self._set_schema(result.record)
393
fields = result.record.field_blobs() if result.record else None
394
return read_nets, status, fields
399
Processor that clones a core.Net each time it's called, executing
400
the cloned net as the processor. It requires the Net to have input
401
and (optionally) output records set, with net.set_input_record() and
402
net.set_output_record().
404
def __init__(self, net, stop_signal=None, thread_init_nets=None, name=None):
405
assert isinstance(net, core.Net)
406
assert stop_signal is None or isinstance(
407
stop_signal, core.BlobReference)
408
self.name = name or str(net)
409
self.thread_init_nets = thread_init_nets or []
411
self._stop_signal = stop_signal
414
self._cloned_init_nets = []
417
return self.net.output_record()
419
def setup(self, init_net):
421
cloned_init_nets = self._cloned_init_nets
422
self._cloned_init_nets = []
423
return cloned_init_nets
425
def __call__(self, rec):
426
assert not self._frozen
427
prefix = NetBuilder.current().name + '/'
429
for net in self.thread_init_nets:
430
new_net, _ = core.clone_and_bind_net(
431
net, str(net) + prefix, prefix, blob_remap)
432
self._cloned_init_nets.append(new_net)
434
new_net, remappings = core.clone_and_bind_net(
435
self.net, str(self.net) + prefix, prefix, blob_remap, rec)
437
if self._stop_signal is None:
439
elif str(self._stop_signal) in remappings:
440
stop_signal = core.BlobReference(
441
remappings[str(self._stop_signal)],
444
stop_signal = self._stop_signal
446
self._blob_maps.append(remappings)
447
return Output([new_net], new_net.output_record(), stop_signal)
451
return self._blob_maps