pytorch

Форк
0
692 строки · 23.6 Кб
1
## @package task
2
# Module caffe2.python.task
3

4
from caffe2.python import core, context
5
from caffe2.python.schema import Field, from_blob_list
6
from collections import defaultdict
7
from copy import copy
8

9

10
def _merge_node_kwargs(a, b):
11
    # TODO(azzolini): consistency checks
12
    if a is None:
13
        return b
14
    if b is None:
15
        return a
16
    c = copy(a)
17
    c.update(b)
18
    return c
19

20

21
class Cluster(context.DefaultManaged):
    """
    Context that keeps track of all the node names used.
    Users shouldn't have to use them directly, since a Cluster is automatically
    generated at the first usage of 'Node'.
    """

    def __init__(self):
        # A list rather than a set so that node registration order is kept.
        self._nodes = []
        self._node_kwargs = {}

    def add_node(self, node):
        """Register `node`, merging its kwargs with any recorded earlier."""
        name = str(node)
        if name not in self._nodes:
            self._nodes.append(name)
        # Previously-recorded kwargs take precedence over the new ones.
        self._node_kwargs[name] = _merge_node_kwargs(
            node.kwargs(),
            self._node_kwargs.get(name))

    def nodes(self):
        """
        Returns the list of unique node names used within this context.
        """
        return self._nodes

    def node_kwargs(self):
        """Return the mapping from node name to merged kwargs."""
        return self._node_kwargs

    def __repr__(self):
        return "Cluster(nodes={}, node_kwargs={})".format(
            self.nodes(), self.node_kwargs())
52

53

54
class Node(context.DefaultManaged):
    """
    A Node context is used to indicate that all Tasks instantiated within will
    run on the given node name. (Only the name of the node actually counts.)
    Example:

        with TaskGroup() as tg:
            with Node('node1'):
                s1 = execution_step(...)
                Task(step=s1)
            with Node('node2'):
                s2 = execution_step(...)
            with Node('node1'):
                s3 = execution_step(...)

        In this example, all three execution steps will run in parallel.
        Moreover, s1 and s3 will run on the same node, and can see each
        others blobs.

        Additionally, a Node can be passed implementation-specific kwargs,
        in order to specify properties of the node.
    """

    def __init__(self, node='local', **kwargs):
        # Only the string name matters for task placement; kwargs carry
        # implementation-specific node properties.
        self._name = str(node)
        self._kwargs = kwargs
        # Register with the active cluster so kwargs are tracked per node.
        Cluster.current().add_node(self)

    def kwargs(self):
        """Return the implementation-specific properties of this node."""
        return self._kwargs

    def __str__(self):
        return self._name

    def __repr__(self):
        return "Node(name={}, kwargs={})".format(self._name, self._kwargs)
90

91

92
class WorkspaceType:
    """
    Selects the workspace a TaskGroup's tasks run in: GLOBAL means the
    shared global workspace, which is kept alive across runs; PRIVATE means
    a fresh child workspace created for the run and destroyed afterwards.
    """
    PRIVATE = 'private'
    GLOBAL = 'global'
100

101

102
def get_setup_nets(key, steps_or_nets, target):
    """
    Collect the init and exit nets attached (as attributes under `key`)
    to the given steps/nets, restricted to setup objects aimed at `target`.
    Returns a pair (init_nets, exit_nets); ops added directly to the
    implicit '<key>/init' and '<key>/exit' nets are prepended when present.
    """
    init_net = core.Net(key + '/init')
    exit_net = core.Net(key + '/exit')
    init_nets = []
    exit_nets = []

    setup_objs = []
    for item in steps_or_nets:
        if hasattr(item, 'get_all_attributes'):
            setup_objs += item.get_all_attributes(key)
        elif hasattr(item, 'get_attributes'):
            setup_objs += item.get_attributes(key)

    def _collect(nets, acc):
        # Normalize a setup()/exit() return value: list/tuple extends the
        # accumulator, a single Net/ExecutionStep appends, None is a no-op.
        if isinstance(nets, (list, tuple)):
            acc += nets
        elif isinstance(nets, (core.Net, core.ExecutionStep)):
            acc.append(nets)
        elif nets is not None:
            raise TypeError('Unsupported type for setup: %s' % type(nets))

    for obj in setup_objs:
        # The _setup_used / _setup_target checks are needed in order to
        # allow nesting of TaskGroup, which is a feature not yet implemented.
        if getattr(obj, '_setup_used', False):
            continue
        if hasattr(obj, '_setup_target') and obj._setup_target != target:
            continue
        if hasattr(obj, 'setup'):
            _collect(obj.setup(init_net), init_nets)
            obj._setup_used = True
        if hasattr(obj, 'exit'):
            _collect(obj.exit(exit_net), exit_nets)
            obj._setup_used = True

    # Only include the implicit nets if setup()/exit() actually added ops.
    if len(init_net.Proto().op) > 0:
        init_nets.insert(0, init_net)
    if len(exit_net.Proto().op) > 0:
        exit_nets.insert(0, exit_net)
    return init_nets, exit_nets
144

145

146
def add_setup_steps(step, init_nets, exit_nets, name):
    """
    Wrap `step` with '<name>:init' and '<name>:exit' execution steps built
    from the given nets. Returns `step` unchanged when both are empty.
    """
    if not init_nets and not exit_nets:
        return step
    substeps = []
    if init_nets:
        substeps.append(core.execution_step('%s:init' % name, init_nets))
    substeps.append(step)
    if exit_nets:
        substeps.append(core.execution_step('%s:exit' % name, exit_nets))
    return core.execution_step(name, substeps)
156

157

158
class TaskGroup(context.Managed):
    """
    Context that gathers tasks which will run concurrently, potentially on
    multiple nodes. All tasks in the same node will share the same workspace
    and thus can share blobs, while tasks running in different nodes won't
    be able to directly share data.

    All tasks of the task group will start concurrently, and the task group
    will finish execution when the last task of the group finishes.

    Example:
        # suppose that s1 ... s5 are execution steps or nets.
        with TaskGroup() as tg:
            # these tasks go to default node 'local'
            Task(step=s1)
            Task(step=s2)

            with Node('n2'):
                Task(step=s3)
            with Node('n1'):
                Task(step=s4)
            with Node('n2'):
                Task(step=s5)

        # this will run all steps in parallel.
        # s1 and s2 will run at default node 'local'
        # s3 and s5 will run at node 'n2'
        # s4 will run at node 'n1'
        session.run(tg)
    """
    # Attribute key under which node-local setup objects are attached.
    LOCAL_SETUP = 'local_setup'

    def __init__(self, workspace_type=None):
        self._plan_cache = None
        # Tasks fully registered via add().
        self._tasks = []
        # Once True (set by tasks()), no further tasks may be added.
        self._already_used = False
        self._prev_active = None
        # Tasks created within this context but not yet moved to _tasks.
        self._tasks_to_add = []
        # node name -> (net, report_interval); deprecated, see report_net().
        self._report_nets = {}
        # List of (node name, step) pairs added via report_step().
        self._report_steps = []
        self._workspace_type = workspace_type
        # Cached result of tasks_by_node(): (grouped TaskGroup, node_map).
        self._tasks_by_node = None
        self._remote_nets = []

    def add_remote_net(self, net):
        # Record a net that runs remotely (outside this group's plan).
        self._remote_nets.append(net)

    def remote_nets(self):
        return self._remote_nets

    def add(self, task):
        """Register `task`, reconciling workspace types between the two."""
        assert not self._already_used, (
            'Cannot add Task to an already used TaskGroup.')
        # Workspace types must agree when both sides specify one.
        assert (
            self._workspace_type is None or
            task._workspace_type is None or
            self._workspace_type == task._workspace_type)
        if task._workspace_type is None:
            task._workspace_type = (
                self._workspace_type or WorkspaceType.PRIVATE)
        if self._workspace_type is None:
            self._workspace_type = task._workspace_type
        task._notify_used()
        self._tasks.append(task)

    def tasks(self):
        """Flush pending tasks, mark the group used, and return all tasks."""
        for task in self._tasks_to_add:
            self.add(task)
        self._tasks_to_add = []
        self._already_used = True
        return self._tasks

    def num_registered_tasks(self):
        # Counts both pending and already-registered tasks.
        return len(self._tasks_to_add) + len(self._tasks)

    def used_nodes(self):
        """Return node names used by any task, in first-use order."""
        # use list to keep order
        used = []
        for task in self._tasks + self._tasks_to_add:
            if task.node not in used:
                used.append(task.node)
        return used

    def report_step(self, step=None, node=None, interval_ms=1000):
        """
        Add a "report step" to this TaskGroup. This step will run repeatedly
        every `interval_ms` milliseconds for the duration of the TaskGroup
        execution on each of the nodes. It is guaranteed that this step
        will be run at least once after every Task in the node has finished.
        """
        step = core.to_execution_step(step)
        step.RunEveryMillis(interval_ms)
        self._report_steps.append((str(node or Node.current(node)), step))

    def report_net(self, net=None, node=None, report_interval=5):
        """
        DEPRECATED. Use report_step instead.
        """
        node = str(node or Node.current(node))
        # Passing an explicit net twice for the same node is an error.
        assert net is None or node not in self._report_nets
        if node not in self._report_nets:
            self._report_nets[node] = (
                net if net else core.Net('%s/reporter' % node),
                report_interval)
        return self._report_nets[node][0]

    def tasks_by_node(self, node_remap=None):
        """
        Collapse this group's tasks into one Task per (possibly remapped)
        node, wiring in report steps and node-local setup/exit nets.
        Returns a new TaskGroup with one grouped Task per node.
        """
        # tasks_by_node can't be called twice because the setup won't
        # work properly a second time.
        node_map = {}
        for task in self.tasks():
            node_map[task.node] =\
                node_remap(task.node) if node_remap else task.node
        if self._tasks_by_node is not None:
            # Cached: only valid if the remapping is identical.
            tasks_by_node, prev_node_map = self._tasks_by_node
            assert prev_node_map == node_map, (
                'Cannot call tasks_by_node multiple times.')
            return tasks_by_node

        # now we have report_steps. report_net is deprecated
        for node, (net, interval) in self._report_nets.items():
            self.report_step(net, node=node, interval_ms=interval * 1000)
        self._report_nets = {}

        tasks_by_node = defaultdict(list)
        for task in self.tasks():
            mapped_node = node_map[task.node]
            tasks_by_node[mapped_node].append(task)

        report_steps_by_node = defaultdict(list)
        for original_node, step in self._report_steps:
            report_steps_by_node[node_map[original_node]].append(step)

        grouped_by_node = TaskGroup()
        for node, tasks in tasks_by_node.items():
            report_steps = report_steps_by_node[node]
            # Node-local setup/exit nets attached to any of the steps.
            node_inits, node_exits = get_setup_nets(
                TaskGroup.LOCAL_SETUP,
                [t.get_step() for t in tasks] + report_steps,
                self)
            # shortcut for single task with no queue
            steps = report_steps
            outputs = []
            grouped_workspace_type = WorkspaceType.PRIVATE
            for task in tasks:
                step = task.get_step()
                step.SetCreateWorkspace(
                    task.workspace_type() == WorkspaceType.PRIVATE)
                if step is not None:
                    steps.append(step)
                outputs += task.outputs()
                # If any of the tasks in the node uses the global workspace,
                # then set the grouped task to use the global workspace as well
                if task.workspace_type() == WorkspaceType.GLOBAL:
                    grouped_workspace_type = WorkspaceType.GLOBAL
            if len(steps) == 0:
                steps.append(core.execution_step('empty', []))
            if len(steps) == 1:
                step = steps[0]
            else:
                # All per-task steps on a node run concurrently.
                step = core.execution_step(
                    '%s:body' % node, steps, concurrent_substeps=True)
            if len(node_inits) > 0 or len(node_exits) > 0:
                # Sandwich the body between node-level init and exit steps.
                steps = []
                if len(node_inits) > 0:
                    steps.append(
                        core.execution_step('%s:init' % node, node_inits))
                steps.append(step)
                if len(node_exits) > 0:
                    steps.append(
                        core.execution_step('%s:exit' % node, node_exits))
                step = core.execution_step(node, steps)
            Task(
                node=node, step=step, outputs=outputs,
                name='grouped_by_node',
                group=grouped_by_node, workspace_type=grouped_workspace_type)
        self._tasks_by_node = (grouped_by_node, node_map)
        return grouped_by_node

    def to_task(self, node=None):
        """Collapse the whole group into a single Task on `node`."""
        node = str(Node.current(node))
        tasks = self.tasks_by_node(lambda x: node).tasks()
        if len(tasks) == 0:
            return Task()
        return tasks[0]

    def workspace_type(self):
        return self._workspace_type

    def __repr__(self):
        return "TaskGroup(tasks={}, workspace_type={}, remote_nets={})".format(
            self._tasks + self._tasks_to_add,
            self.workspace_type(),
            self.remote_nets())
352

353

354
class TaskOutput:
    """
    Represents the output of a task. An output can be a blob,
    a list of blobs, or a record.
    """

    def __init__(self, names):
        self._schema = None
        self._is_scalar = False
        if isinstance(names, Field):
            # A record: remember the schema and work on its blob list.
            self._schema = names
            names = self._schema.field_blobs()
        self._is_scalar = type(names) not in (tuple, list)
        if self._is_scalar:
            names = [names]
        self.names = names
        self._values = None

    def set(self, values, _fetch_func=None):
        """Bind concrete values (and, optionally, a fetch function)."""
        assert len(values) == len(self.names)
        self._values = values
        self._fetch_func = _fetch_func

    def _as_result(self, values):
        # Present `values` with the same convention as the constructor:
        # scalar outputs unwrap, schemas rebuild the record, else a list.
        if self._is_scalar:
            return values[0]
        if self._schema:
            return from_blob_list(self._schema, values)
        return values

    def get(self):
        assert self._values is not None, 'Output value not set yet.'
        return self._as_result(self._values)

    def fetch(self):
        assert self._fetch_func is not None, (
            'Cannot fetch value for this output.')
        return self._as_result([self._fetch_func(v) for v in self._values])

    def __repr__(self):
        return "TaskOutput(names={}, values={})".format(self.names, self._values)
399

400

401
def final_output(blob_or_record):
    """
    Adds an output to the current Task, or if no task is active,
    create a dummy task that returns the given blob or record
    to the client. This will return the value of the blob or record when
    the last task of the TaskGroup for a given node finishes.
    """
    active_task = Task.current(required=False) or Task()
    return active_task.add_output(blob_or_record)
410

411

412
class TaskOutputList:
    """ Keeps a list of outputs for a task """

    def __init__(self, outputs=None):
        self.outputs = outputs or []

    def names(self):
        """
        Retrieve the output names.
        TODO(azzolini): make this schema-based.
        """
        all_names = []
        for output in self.outputs:
            all_names += output.names
        return all_names

    def set_values(self, values, _fetch_func=None):
        """Distribute `values` to each output, in order, by name count."""
        pos = 0
        for output in self.outputs:
            count = len(output.names)
            output.set(values[pos:pos + count], _fetch_func)
            pos += count
        assert pos == len(values), 'Wrong number of output values.'

    def __repr__(self):
        return "TaskOutputList(outputs={})".format(self.outputs)
437

438

439
class Task(context.Managed):
    """
    A Task is composed of an execution step and zero or more outputs.
    Tasks are executed in the context of a TaskGroup, which, in turn, can
    be run by a Session.

    Task outputs are fetched by the session at the end of the run.

    The recommended way of creating a task is by using `net_builder.ops`.
    Example:

        from net_builder import ops
        with Node('trainer'), Task(name='my_task', num_instances=2):
            with ops.task_init():
                globl = ops.Const(0)
            with ops.task_instance_init():
                local = ops.Const(0)
            with ops.loop(100):
                ops.Copy(globl, local)
            with ops.task_instance_exit():
                ops.Add([globl, local], [globl])
            with ops.task_exit():
                ops.Mul([globl, globl], [globl])

    The task above will create 2 instances that will run in parallel.
    Each instance will copy `local` to `globl` 100 times, Then Add `local`
    to `globl` once. The `Mul` will only execute once, after all the instances
    of the task have finished.
    """

    # TASK_SETUP runs once per task, before/after all
    # concurrent task instances start/finish.
    TASK_SETUP = 'task_setup'
    # Setup will run once for each instance of the task.
    TASK_INSTANCE_SETUP = 'task_instance_setup'
    # Attribute key used to locate report steps attached to a step.
    REPORT_STEP = 'report_step'
    # Names used by tasks created outside of any TaskGroup.
    _global_names_used = set()

    @staticmethod
    def _get_next_name(node, group, name):
        """Return a name '<node>/<name>[:<i>]' unique within `group`
        (or against the global name set when no group is given)."""
        basename = str(node) + '/' + str(name)
        names_used = (
            Task._global_names_used
            if group is None else
            set(t.name for t in group._tasks_to_add))
        cur_name = basename
        i = 0
        while cur_name in names_used:
            i += 1
            cur_name = '%s:%d' % (basename, i)
        return cur_name

    def __init__(
            self, step=None, outputs=None,
            workspace_type=None, group=None, node=None, name=None,
            num_instances=None):
        """
        Instantiate a Task and add it to the current TaskGroup and Node.

        Args:
           step:    If provided, this task will run this ExecutionStep.
           outputs: If provided, the task will return the provided outputs
                    to the client at completion time.
           node:    If provided, force task execution on the given node.
           name:    Name of the Task.
           num_instances: If provided, this task will be cloned num_instances
                          times at runtime, and all instances will run
                          concurrently.
        """
        # Derive a name from the step when none is given.
        if not name and isinstance(step, core.ExecutionStep):
            name = step.Proto().name
        if not name:
            name = 'task'
        # register this node name with active context
        self.node = str(Node.current(None if node is None else Node(node)))
        self.group = TaskGroup.current(group, required=False)

        self.name = Task._get_next_name(self.node, self.group, name)

        # may need to be temporarily removed later if Task used as a context
        if self.group is not None:
            self.group._tasks_to_add.append(self)

        self._already_used = False
        self._step = None
        # Lazily-built final step (body wrapped with setup/exit); see get_step.
        self._step_with_setup = None
        self._outputs = []
        if step is not None:
            self.set_step(step)
        if outputs is not None:
            self.add_outputs(outputs)

        self._pipeline = None
        self._is_pipeline_context = False
        self._workspace_type = workspace_type
        self._report_net = None
        self._num_instances = num_instances

    def __enter__(self):
        """Start recording ops for this task through a NetBuilder."""
        super().__enter__()

        # temporarily remove from _tasks_to_add to ensure correct order
        if self.group is not None:
            self.group._tasks_to_add.remove(self)
        self._assert_not_used()
        assert self._step is None, 'This Task already has an execution step.'
        # Local import to avoid a circular dependency at module load time.
        from caffe2.python import net_builder
        self._net_builder = net_builder.NetBuilder(_fullname=self.name)
        self._net_builder.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Stop recording; adopt the built step only on clean exit."""
        super().__exit__(exc_type, exc_value, exc_traceback)

        self._net_builder.__exit__(exc_type, exc_value, exc_traceback)
        if exc_type is None:
            self.set_step(self._net_builder)
        if self.group is not None:
            # Re-add now that the task body is complete (see __enter__).
            self.group._tasks_to_add.append(self)
        self._net_builder = None

    def workspace_type(self):
        return self._workspace_type

    def _assert_not_used(self):
        assert not self._already_used, (
            'Cannot modify task since it is already been used.')

    def add_output(self, output):
        """Add one output (wrapped in TaskOutput if needed) and return it."""
        self._assert_not_used()
        output = (
            output if isinstance(output, TaskOutput) else TaskOutput(output))
        self._outputs.append(output)
        return output

    def add_outputs(self, outputs):
        """Add one output or a list/tuple of outputs; mirrors the input shape."""
        self._assert_not_used()
        if type(outputs) not in (list, tuple):
            return self.add_output(outputs)
        else:
            return [self.add_output(output) for output in outputs]

    def set_step(self, step):
        self._assert_not_used()
        self._step = core.to_execution_step(step)

    def get_step(self):
        """
        Build (once) and return the task's final execution step: the body
        wrapped with report steps, instance-level setup/exit, optional
        concurrent instances, and task-level setup/exit.
        """
        if self._step_with_setup is not None:
            return self._step_with_setup

        if self._step is None:
            # No body: return an empty placeholder step.
            self._step_with_setup = core.execution_step(self.name, [])
            return self._step_with_setup

        # Claim report steps not yet consumed by another task.
        report_steps = [
            s
            for s in self._step.get_all_attributes(Task.REPORT_STEP)
            if not hasattr(s, '_report_step_used')
        ]
        for step in report_steps:
            step._report_step_used = True
            if not step.Proto().run_every_ms:
                step.RunEveryMillis(1000)
        task_init_nets, task_exit_nets = get_setup_nets(
            Task.TASK_SETUP, [self._step] + report_steps, self)
        instance_init_nets, instance_exit_nets = get_setup_nets(
            Task.TASK_INSTANCE_SETUP, [self._step] + report_steps, self)
        if len(self._outputs) == 0:
            # Every task needs at least one output for the session to fetch;
            # create a dummy constant output.
            output_net = core.Net('%s:output' % self.name)
            self.add_output(output_net.ConstantFill(
                [], 1, dtype=core.DataType.INT32, value=0))
            task_exit_nets.append(output_net)

        # Add instance-level report steps
        body = self._step if not report_steps else core.execution_step(
            '%s:body' % self.name, report_steps + [self._step])
        # Enclose with instance-level (thread-local) setup nets
        step_with_instance_setup = add_setup_steps(
            body, instance_init_nets, instance_exit_nets,
            self.name + ':instance')
        # Set up runtime concurrent instances
        if self._num_instances and self._num_instances > 1:
            step_with_instance_setup.SetCreateWorkspace(True)
            step_with_instance_setup = core.execution_step(
                # Fix: the task name was never interpolated into the step
                # name (the literal '%s:parallel' was used before).
                '%s:parallel' % self.name,
                [step_with_instance_setup],
                num_concurrent_instances=self._num_instances)
        # Enclose with task-level setup nets
        self._step_with_setup = add_setup_steps(
            step_with_instance_setup, task_init_nets, task_exit_nets, self.name)

        return self._step_with_setup

    def output_list(self):
        return TaskOutputList(self._outputs)

    def outputs(self):
        return self._outputs

    def _notify_used(self):
        # Force the final step to be built, then freeze the task.
        self.get_step()
        self._already_used = True

    def __repr__(self):
        return "Task(name={}, node={}, outputs={})".format(
            self.name, self.node, self.outputs())
645

646

647
class SetupNets:
    """
    Allow to register a list of nets to be run at initialization
    and finalization of Tasks or TaskGroups.
    For example, let's say you have the following:

        init_net = core.Net('init')
        my_val = init_net.ConstantFill([], 'my_val', value=0)

        net = core.Net('counter')
        net.Add([my_val, net.Const(1),], [my_val])

        with TaskGroup() as task_group:
            with Node('trainer'):
                my_task = Task(step=[net])

    In order to have `init_net` run once before `net` runs for the
    first time, you can do one of the following:

        net.add_attribute(Task.TASK_SETUP, SetupNets([init_net]))

    or

        net.add_attribute(TaskGroup.LOCAL_SETUP, SetupNets([init_net]))

    - With Task.TASK_SETUP, init_net will run once at my_task startup.
    - With TaskGroup.LOCAL_SETUP, init_net will run once on node 'trainer',
      before any task of the task group is run on that node.

    The same SetupNets object can be added to multiple nets. It will only
    run once per Task/TaskGroup run.
    """

    def __init__(self, init_nets=None, exit_nets=None):
        self.init_nets = init_nets
        self.exit_nets = exit_nets

    def setup(self, init_net):
        """Return the nets to run at initialization (`init_net` is unused)."""
        return self.init_nets

    def exit(self, exit_net):
        """Return the nets to run at finalization (`exit_net` is unused)."""
        return self.exit_nets

    def __repr__(self):
        return "SetupNets(init_nets={}, exit_nets={})".format(
            self.init_nets, self.exit_nets)
693

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.