pytorch

Форк
0
/
pipeline.py 
451 строка · 16.9 Кб
1
## @package pipeline
2
# Module caffe2.python.pipeline
3

4

5

6

7

8
from caffe2.python import core, queue_util
9
from caffe2.python.dataio import Reader, Writer
10
from caffe2.python.net_builder import NetBuilder, ops
11
from caffe2.python.schema import as_record, Field
12
from caffe2.python.task import Node, Task, TaskGroup
13

14

15
class Output:
    """
    Wraps the result of a processor function.

    A processor may either return an Output directly, or return a plain
    record, in which case an Output is constructed for it afterwards.
    """
    def __init__(self, nets=None, record=None, should_stop=None):
        # Nets created through the `ops` syntax are collected from the
        # currently active NetBuilder; mixing that with an explicit
        # `nets` list is ambiguous and therefore rejected.
        builder_children = NetBuilder.current().get()
        assert nets is None or len(builder_children) == 0, (
            'Cannot both use `ops` syntax and return a list of nets.')
        if nets is None:
            nets = builder_children
        elif isinstance(nets, core.Net):
            nets = [nets]
        self.nets = list(nets)
        self.record = None if record is None else as_record(record)
        self.should_stop = should_stop
32

33

34
# Capacity used by _init_output() for the implicitly created output queue
# when the caller provides neither an `output` nor an explicit `capacity`.
DEFAULT_QUEUE_CAPACITY = 100
35

36

37
def _init_output(output, capacity, global_init_net, global_exit_net):
    """
    Resolve `output` into an (out_queue, writer) pair and set the writer up.

    When `output` is None a fresh Queue is created, sized by `capacity`
    (falling back to DEFAULT_QUEUE_CAPACITY). A Writer is used as-is
    (no queue); anything exposing a `writer()` method (queue or stream)
    contributes both itself and its writer. Other values are rejected.
    """
    if output is None:
        queue_capacity = DEFAULT_QUEUE_CAPACITY if capacity is None else capacity
        out_queue = queue_util.Queue(capacity=queue_capacity)
        writer = out_queue.writer()
    elif isinstance(output, Writer):
        # An explicit writer has no associated queue; capacity is meaningless.
        assert capacity is None, 'capacity would not be used.'
        out_queue, writer = None, output
    elif hasattr(output, 'writer'):
        assert capacity is None, 'capacity would not be used.'
        out_queue, writer = output, output.writer()
    else:
        raise ValueError('output must be a reader, queue or stream.')
    writer.setup_ex(global_init_net, global_exit_net)
    return out_queue, writer
56

57

58
def make_processor(processor, reader=None):
    """
    Normalize `processor` into a callable record processor.

    None becomes the identity function; a core.Net is wrapped in a
    NetProcessor; any other value is returned unchanged, except that a
    processor exposing `schema_func` gains a `schema` method bound to
    the given reader.
    """
    if processor is None:
        return lambda rec: rec
    if isinstance(processor, core.Net):
        return NetProcessor(processor)
    if reader is not None and hasattr(processor, "schema_func"):
        def processor_schema():
            return processor.schema_func(reader)

        processor.schema = processor_schema
    return processor
70

71

72
def normalize_processor_output(output):
    """
    Coerce the several return conventions of processors into an Output.

    Accepted forms: an Output instance; a record (Field); a 2-tuple of
    (record, stop_blob); a tuple of (nets, record, stop_blob); or bare
    nets with no output record.
    TODO(azzolini): simplify once all processors use NetBuilder API.
    """
    if isinstance(output, Output):
        # Already normalized.
        return output
    if isinstance(output, Field):
        # Bare record.
        return Output(record=output)
    if isinstance(output, tuple):
        if (len(output) == 2 and isinstance(output[0], Field)
                and isinstance(output[1], core.BlobReference)):
            # (record, stop_blob) -- no nets.
            return Output(None, *output)
        # (nets, record, stop_blob)
        return Output(*output)
    # Bare nets, no output record.
    return Output(output)
97

98

99
def pipe(
        input, output=None, num_threads=1, processor=None, name=None,
        capacity=None, group=None, num_runtime_threads=1):
    """
    Create a Task that pipes `input` into `output` using parallel threads.

    The input is read until either the reader or the writer signals a
    stop. When a processor is supplied, it runs between the read and
    write steps and may transform the record.

    Args:
        input:       a Reader, Queue or DataStream to read from.
        output:      optional Writer, Queue or DataStream that is written
                     to as long as neither side signals a stop. When None,
                     a Queue with the given `capacity` is created and
                     written to.
        num_threads: number of concurrent threads used for processing and
                     piping, expanded with a python `for` loop. If 0, no
                     Task is created and a reader is returned instead that
                     reads from `input` and processes it.
                     ** DEPRECATED ** -- use `num_runtime_threads`; this
                     option will be removed once all readers/processors
                     support `num_runtime_threads`.
        processor:   optional function taking an input record and
                     optionally returning a record; called between the
                     read and write steps. If it returns no record, no
                     writer is instantiated. May also be a core.Net with
                     input and output records set, in which case a
                     NetProcessor is instantiated, cloning the net for
                     each thread.
        name:        optional name for the created task.
        capacity:    capacity of the queue created when `output` is not
                     passed.
        group:       optional TaskGroup to add the created Task to,
                     instead of the currently active one.
        num_runtime_threads: like `num_threads`, but the expansion happens
                     at runtime instead of via a python loop. Preferable
                     to `num_threads`, though some processors/readers
                     still require multiple python-level calls.

    Returns:
        Output Queue, DataStream, Reader, or None, depending on the
        parameters passed.
    """
    out, _task = _pipe_step(
        input, output=output, num_threads=num_threads, processor=processor,
        name=name, capacity=capacity, group=group,
        num_runtime_threads=num_runtime_threads)
    return out
150

151

152
def pipe_and_output(
        input, output=None, num_threads=1, processor=None, name=None,
        capacity=None, group=None, num_runtime_threads=1, final_outputs=None):
    """
    Like `pipe`, but additionally lets the pipe Task return output values
    (`final_outputs`) to the Session once it is done.

    Returns:
        Tuple (out_queue, task_outputs):
            out_queue:    same as the return value of `pipe`.
            task_outputs: TaskOutput object(s), fetchable from the client
                          after session.run() returns; None when
                          `final_outputs` was not given.
    """
    assert num_threads > 0
    result, task = _pipe_step(
        input, output, num_threads, processor, name, capacity, group,
        num_runtime_threads, final_outputs)
    task_outputs = None
    if final_outputs is not None:
        task_outputs = task.outputs()
        # A scalar `final_outputs` yields a scalar result, not a list.
        if type(final_outputs) not in (list, tuple):
            task_outputs = task_outputs[0]
    return result, task_outputs
175

176

177
def processor_name(processor):
    """
    Return a human-readable name for `processor`, used for task names and
    profiler counter labels.

    Resolution order: an explicit `name` attribute; the function's name
    (lambdas fall back to their module, Python 2 bound methods are
    prefixed with their class); otherwise the processor's class name.
    """
    if hasattr(processor, 'name'):
        return processor.name
    # `func_name` only existed on Python 2 functions; on Python 3 the
    # attribute is `__name__`. Checking only `func_name` made every
    # function fall through to `__class__.__name__` (i.e. 'function')
    # on Python 3, so consult both.
    func_name = getattr(processor, 'func_name', None) or getattr(
        processor, '__name__', None)
    if func_name is not None:
        if func_name == '<lambda>':
            return processor.__module__
        # Python 2 bound/unbound methods carried `im_class`.
        im_class = getattr(processor, 'im_class', None)
        if im_class is not None:
            return '%s.%s' % (im_class.__name__, func_name)
        return func_name
    return processor.__class__.__name__
187

188

189
def _runtime_threads_task(name, group, final_outputs, reader, num_threads,
                          output, capacity):
    """
    Build a pipe Task whose parallelism is expanded at runtime via
    `num_instances`, rather than with a python `for` loop.

    Args:
        name:          name of the Task to create.
        group:         TaskGroup to add the Task to (None for the active one).
        final_outputs: values the Task returns to the Session when done.
        reader:        reader (possibly a ProcessingReader) to pull from.
        num_threads:   number of runtime task instances.
        output:        Writer/Queue/DataStream to write to (None creates one).
        capacity:      capacity for the implicitly created queue.

    Returns:
        Tuple (out_queue, task); out_queue is None when the reader yields
        no record or an explicit writer was given.
    """
    node_name = str(Node.current())
    # BUG FIX: this previously read `processor_name(input) if input ...`,
    # but this function has no `input` argument -- that referenced the
    # `input` *builtin*, producing a meaningless label. Use the actual
    # reader being piped.
    profiler_name = "{0}/{1}/{2}/{3}/{4}".format(
        node_name,
        "pipe",
        name,
        processor_name(reader) if reader else "NoInput",
        processor_name(output) if output else "NoOutput")

    with Task(name=name, group=group, outputs=final_outputs,
              num_instances=num_threads) as task:
        # Task-wide init/exit nets, shared by all instances.
        global_exit_net = core.Net('pipe:exit')
        global_init_net = core.Net('pipe:init')
        reader.setup_ex(global_init_net, global_exit_net)

        # Per-instance init/exit nets.
        init_net = core.Net('pipe:instance:init')
        exit_net = core.Net('pipe:instance:exit')
        read_nets, status, rec = reader.read_record_ex(init_net, exit_net)
        # Reset the stop flag at instance start.
        init_net.ConstantFill(
            [], [status],
            shape=[],
            value=False,
            dtype=core.DataType.BOOL
        )

        if rec is not None:
            out_queue, writer = _init_output(
                output, capacity, global_init_net, global_exit_net)
            write_nets, _ = writer.write_record_ex(
                rec, init_net, exit_net, status)
        else:
            # Reader produces no record: nothing to write.
            out_queue = None
            write_nets = []

        with ops.task_init():
            ops.net(global_init_net)
        with ops.task_instance_init():
            ops.net(init_net)

        # Time each read+write iteration under `profiler_name`.
        timer_start_net = core.Net('timer_start')
        timer = timer_start_net.TimerBegin([], counter_name=profiler_name)
        timer_end_net = core.Net('timer_end')
        timer_end_net.TimerEnd(timer, [])

        ops.net(core.execution_step(
            'body',
            [timer_start_net] + list(read_nets) + list(write_nets) +
            [timer_end_net],
            should_stop_blob=status))
        ops.net(timer_end_net)

        with ops.task_instance_exit():
            ops.net(exit_net)
        with ops.task_exit():
            ops.net(global_exit_net)

    return out_queue, task
247

248

249
def _static_threads_task(name, group, final_outputs, reader, num_threads,
                         output, capacity):
    """
    Build a pipe Task whose parallelism is expanded statically in python:
    one NetBuilder/execution step per thread, run concurrently.

    Args:
        name:          name of the Task to create.
        group:         TaskGroup to add the Task to (None for the active one).
        final_outputs: values the Task returns to the Session when done.
        reader:        reader (possibly a ProcessingReader) to pull from.
        num_threads:   number of python-expanded threads.
        output:        Writer/Queue/DataStream to write to (None creates one).
        capacity:      capacity for the implicitly created queue.

    Returns:
        Tuple (out_queue, task); out_queue is None when the reader yields
        no record or an explicit writer was given.
    """
    node_name = str(Node.current())
    # BUG FIX: this previously read `processor_name(input) if input ...`,
    # but this function has no `input` argument -- that referenced the
    # `input` *builtin*, producing a meaningless label. Use the actual
    # reader being piped.
    profiler_name = "{0}/{1}/{2}/{3}/{4}".format(
        node_name,
        "pipe",
        name,
        processor_name(reader) if reader else "NoInput",
        processor_name(output) if output else "NoOutput")

    with Task(name=name, group=group, outputs=final_outputs) as task:
        global_exit_net = core.Net('exit')
        global_init_net = core.Net('init')
        reader.setup_ex(global_init_net, global_exit_net)

        out_queue = None
        writer = None

        steps = []
        for thread_id in range(num_threads):
            with NetBuilder(name='t:%d' % thread_id) as nb:
                init_net = core.Net('init')
                exit_net = core.Net('exit')
                read_nets, status, rec = reader.read_record_ex(
                    init_net, exit_net)
                # Reset the stop flag at thread start.
                init_net.ConstantFill(
                    [], [status],
                    shape=[],
                    value=False,
                    dtype=core.DataType.BOOL
                )

                if rec is not None:
                    if writer is None:
                        # hack so that the out queue gets the right name prefix
                        # (otherwise they would be prefixed with the thread id)
                        with NetBuilder(_fullname=task.name):
                            out_queue, writer = _init_output(
                                output, capacity, global_init_net,
                                global_exit_net)
                    write_nets, _ = writer.write_record_ex(
                        rec, init_net, exit_net, status)
                else:
                    write_nets = []

                # Time each read+write iteration under `profiler_name`.
                timer_start_net = core.Net('timer_start')
                timer = timer_start_net.TimerBegin([], counter_name=profiler_name)
                timer_end_net = core.Net('timer_end')
                timer_end_net.TimerEnd(timer, [])

                ops.net(init_net)
                ops.net(core.execution_step(
                    'body',
                    [timer_start_net] + list(read_nets) + list(write_nets) +
                    [timer_end_net],
                    should_stop_blob=status))
                ops.net(timer_end_net)
                ops.net(exit_net)
            steps.append(core.to_execution_step(nb))
        ops.net(global_init_net)
        # Run all per-thread steps concurrently.
        ops.net(core.execution_step('body', steps, concurrent_substeps=True))
        ops.net(global_exit_net)
    return out_queue, task
312

313

314
def _pipe_step(
        input, output=None, num_threads=1, processor=None, name=None,
        capacity=None, group=None, num_runtime_threads=None, final_outputs=None):
    """
    Shared implementation behind `pipe` and `pipe_and_output`.

    Resolves `input` into a Reader, optionally wraps it with a
    ProcessingReader, derives a task name, and dispatches to either the
    static-thread or the runtime-thread task builder.

    Returns:
        Tuple (result, task): `result` is the output queue/stream (or the
        reader itself when num_threads == 0); `task` is the created Task,
        or None when no task was created.
    """
    # `num_runtime_threads` defaults to None (unset); guard it so the
    # comparison cannot raise TypeError on Python 3.
    assert num_threads <= 1 or num_runtime_threads is None or \
        num_runtime_threads <= 1, (
        'Only one of num_threads or num_runtime_threads must be set.')

    if isinstance(input, Reader):
        reader = input
    elif hasattr(input, 'reader'):
        reader = input.reader()
    else:
        raise ValueError(
            'Input must be a reader, queue or stream. Got {}'.format(type(input)))

    if processor is not None:
        reader = ProcessingReader(reader, processor)

    if num_threads == 0 or num_runtime_threads == 0:
        # No task: hand back a reader that performs the processing inline.
        assert output is None
        return reader, None

    # Derive a task name, preferring processor > output > input.
    if name is None and processor is not None:
        name = processor_name(processor)
    if name is None and output is not None:
        name = 'pipe_into:%s' % processor_name(output)
    if name is None:
        name = 'pipe_from:%s' % processor_name(input)

    if num_threads > 1:
        return _static_threads_task(
            name, group, final_outputs, reader, num_threads, output, capacity)
    else:
        return _runtime_threads_task(
            name, group, final_outputs, reader, num_runtime_threads, output,
            capacity)
351

352

353
class ProcessingReader(Reader):
    """
    Reader that reads from an upstream reader, calls the processor, and returns
    the processed record.
    """
    def __init__(self, reader, processor):
        Reader.__init__(self)
        self.reader = reader
        # Normalize the processor: None -> identity, core.Net -> NetProcessor,
        # and wire up schema_func-based schemas against the upstream reader.
        self.processor = make_processor(processor, reader)

    def schema(self):
        """Schema of the processed (not the upstream) record."""
        return self.processor.schema()

    def setup_ex(self, init_net, finish_net):
        """Delegate global setup/teardown to the upstream reader."""
        self.reader.setup_ex(init_net, finish_net)

    def read_ex(self, init_net, exit_net):
        """
        Read a record upstream, run it through the processor, and fold any
        processor stop signals into the upstream `status` blob.

        Returns (read_nets, status_blob, field_blobs-or-None).
        """
        read_nets, status, rec = self.reader.read_record_ex(init_net, exit_net)
        # We don't use `status` as the stop_blob of the NetBuilder because it
        # is not guaranteed to end up being the true stop blob. For example,
        # ReaderWithLimitBase doesn't pass the status through but rather copies
        # from it.
        with NetBuilder() as nb:
            # Current NetBuilder is optionally used inside the processor,
            # then its children are retrieved inside of
            # normalize_processor_output.
            # Once readers and writers also use NetBuilder,
            # this logic will be more natural.
            result = normalize_processor_output(self.processor(rec))
        read_nets += result.nets
        # OR the processor-provided stop signal and/or the NetBuilder's stop
        # blob into the upstream status, so downstream loops terminate.
        if result.should_stop or nb._stop_blob:
            stop_net = core.Net('stop_net')
            if result.should_stop:
                stop_net.Or([status, result.should_stop], [status])
            if nb._stop_blob:
                stop_net.Or([status, nb._stop_blob], [status])
            read_nets.append(stop_net)
        if hasattr(self.processor, 'setup'):
            # Processors with a `setup` hook (e.g. NetProcessor) get registered
            # as local setup objects on the task's init net.
            init_net.add_attribute(TaskGroup.LOCAL_SETUP, self.processor)
        self._set_schema(result.record)
        fields = result.record.field_blobs() if result.record else None
        return read_nets, status, fields
395

396

397
class NetProcessor:
    """
    Processor that clones a core.Net each time it's called, executing
    the cloned net as the processor. It requires the Net to have input
    and (optionally) output records set, with net.set_input_record() and
    net.set_output_record().
    """
    def __init__(self, net, stop_signal=None, thread_init_nets=None, name=None):
        assert isinstance(net, core.Net)
        assert stop_signal is None or isinstance(
            stop_signal, core.BlobReference)
        self.name = name or str(net)
        self.thread_init_nets = thread_init_nets or []
        self.net = net
        self._stop_signal = stop_signal
        # Remapping dicts collected from each clone, exposed via blob_maps().
        self._blob_maps = []
        # Once frozen (after setup()/blob_maps()), no further clones allowed.
        self._frozen = False
        self._cloned_init_nets = []

    def schema(self):
        """Output record of the wrapped net."""
        return self.net.output_record()

    def setup(self, init_net):
        """Freeze and hand back the init nets cloned since the last call."""
        self._frozen = True
        pending, self._cloned_init_nets = self._cloned_init_nets, []
        return pending

    def __call__(self, rec):
        """Clone the net (and its thread-init nets) bound to record `rec`."""
        assert not self._frozen
        prefix = NetBuilder.current().name + '/'
        blob_remap = {}
        # Clone the per-thread init nets first so their blobs share the
        # same remapping as the main net's clone.
        for thread_net in self.thread_init_nets:
            cloned_init, _ = core.clone_and_bind_net(
                thread_net, str(thread_net) + prefix, prefix, blob_remap)
            self._cloned_init_nets.append(cloned_init)

        cloned_net, remappings = core.clone_and_bind_net(
            self.net, str(self.net) + prefix, prefix, blob_remap, rec)

        # Translate the stop signal into the clone's namespace when it was
        # remapped; otherwise keep the original reference (or None).
        stop_signal = self._stop_signal
        if stop_signal is not None and str(stop_signal) in remappings:
            stop_signal = core.BlobReference(
                remappings[str(stop_signal)],
                net=cloned_net)

        self._blob_maps.append(remappings)
        return Output([cloned_net], cloned_net.output_record(), stop_signal)

    def blob_maps(self):
        """Freeze and return the remapping dicts of all clones made so far."""
        self._frozen = True
        return self._blob_maps
452

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.