# pytorch
#
# Fork
# 0
# 3070 lines · 116.3 KB
1
## @package core
2
# Module caffe2.python.core
3

4

5

6

7

8
from collections import namedtuple, OrderedDict, defaultdict
9
from past.builtins import basestring
10
from itertools import chain
11
from typing import Dict, Set
12

13
from caffe2.proto import caffe2_pb2
14
from caffe2.python import scope, utils, workspace
15
from caffe2.python.lazy import TriggerLazyImport
16
from caffe2.python.control_ops_grad import \
17
    gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output
18

19
import caffe2.python._import_c_extension as C
20

21
import copy
22
import pickle
23
import numpy as np
24
import sys
25
import traceback
26
import os
27

28
# macOS-specific notice. The platform test runs first, so the C-level DB
# registry is only queried on macOS.
if sys.platform == 'darwin':
    if 'leveldb' in C.registered_dbs():
        print(
            'If you are using homebrew leveldb on a Mac OS, you might see an '
            'error warning you that malloc_zone_unregister() failed. This is '
            'not a caffe2 issue but is due to the homebrew leveldb having an '
            'incompatible memory allocator. It does not affect usage.'
        )
34

35
# Re-export the scope helpers so callers can use core.DeviceScope(...) and
# core.NameScope(...) without importing caffe2.python.scope themselves.
DeviceScope, NameScope = scope.DeviceScope, scope.NameScope
38

39

40
# Bring datatype enums to the main namespace
class DataType:
    """Python-side copy of the ``caffe2_pb2.TensorProto.DataType`` values.

    The numbers must match caffe2.proto; ``_CheckDataType()`` compares every
    enum name from the proto against the attributes of this class. There is
    no entry with value 11.
    """

    UNDEFINED = 0

    # Basic element types.
    FLOAT = 1
    INT32 = 2
    BYTE = 3
    STRING = 4
    BOOL = 5

    # Fixed-width integer types.
    UINT8 = 6
    INT8 = 7
    UINT16 = 8
    INT16 = 9
    INT64 = 10

    # Other floating point widths.
    FLOAT16 = 12
    DOUBLE = 13

    # Special-purpose entries.
    ZERO_COLLISION_HASH = 14
    REBATCHING_BUFFER = 15
57

58

59
def _CheckDataType():
    """Verify that ``DataType`` mirrors caffe2.proto's TensorProto.DataType.

    Every (name, value) pair of the proto enum must exist on ``DataType``
    with the same value.

    Raises:
        AssertionError: on the first missing or mismatching entry.
    """
    for enum_name, proto_value in caffe2_pb2.TensorProto.DataType.items():
        mirrored = getattr(DataType, enum_name, None)
        if mirrored == proto_value:
            continue
        raise AssertionError(
            f"DataType {enum_name} does not match the value defined in "
            f"caffe2.proto: {mirrored} vs {proto_value}"
        )


# Run the consistency check once, when the module is imported.
_CheckDataType()
72

73

74
def _GetRegisteredOperators():
    """Return the workspace's registered operator keys as a new set."""
    registered = workspace.RegisteredOperators()
    return set(registered)


# Snapshot taken at import time; RefreshRegisteredOperators() replaces it.
_REGISTERED_OPERATORS = _GetRegisteredOperators()
79

80

81
def RefreshRegisteredOperators(trigger_lazy=True):
    """Rebuild the module-level ``_REGISTERED_OPERATORS`` snapshot.

    Args:
        trigger_lazy: when true, call ``TriggerLazyImport()`` before taking
            the new snapshot (presumably so lazily-imported operator
            libraries are registered first).
    """
    global _REGISTERED_OPERATORS
    if trigger_lazy:
        TriggerLazyImport()
    _REGISTERED_OPERATORS = _GetRegisteredOperators()
86

87

88
# Arguments (without args[0]) accumulated across all GlobalInit() calls.
_GLOBAL_INIT_ARGS = []


def GlobalInit(args):
    """Initialize the caffe2 runtime with argv-style ``args``.

    The full ``args`` list is passed to ``C.global_init``. Everything after
    ``args[0]`` (conventionally the program name) is also remembered so that
    GetGlobalInitArgs() can report it.
    """
    TriggerLazyImport()
    _GLOBAL_INIT_ARGS.extend(args[1:])
    C.global_init(args)


def GetGlobalInitArgs():
    """Return a copy of the arguments recorded by earlier GlobalInit calls."""
    return list(_GLOBAL_INIT_ARGS)
99

100

101
def IsOperator(op_type):
    """Return whether ``op_type`` is registered for the DEFAULT engine."""
    return IsOperatorWithEngine(op_type, engine='DEFAULT')


def IsOperatorWithEngine(op_type, engine):
    """Return whether ``op_type`` is registered for ``engine``.

    Pending lazy imports are triggered first. The lookup is made against the
    cached ``_REGISTERED_OPERATORS`` snapshot, which this function does not
    refresh (see RefreshRegisteredOperators).
    """
    TriggerLazyImport()
    registry_key = C.op_registry_key(op_type, engine)
    return registry_key in _REGISTERED_OPERATORS
108

109

110
def IsGPUDeviceType(device_type):
    """Return True if ``device_type`` is CUDA or HIP, False otherwise."""
    gpu_device_types = frozenset((caffe2_pb2.CUDA, caffe2_pb2.HIP))
    return device_type in gpu_device_types
112

113

114
def DeviceOption(
    device_type,
    device_id=0,
    random_seed=None,
    node_name=None,
    numa_node_id=None,
    extra_info=None,
):
    """Build and return a ``caffe2_pb2.DeviceOption`` proto.

    Args:
        device_type: written to ``device_type`` unconditionally.
        device_id: written to ``device_id`` unconditionally.
        random_seed: written only when not None.
        node_name: written only when not None.
        numa_node_id: written only when not None; only allowed together
            with a CPU ``device_type`` (checked with an assert).
        extra_info: if not None, its entries are appended to the repeated
            ``extra_info`` field.
    """
    option = caffe2_pb2.DeviceOption()
    option.device_type = device_type
    option.device_id = device_id

    for field_name, field_value in (
        ('node_name', node_name),
        ('random_seed', random_seed),
    ):
        if field_value is not None:
            setattr(option, field_name, field_value)

    if numa_node_id is not None:
        # NUMA placement is accepted only for CPU devices.
        assert device_type == caffe2_pb2.CPU
        option.numa_node_id = numa_node_id

    if extra_info is not None:
        option.extra_info.extend(extra_info)
    return option
135

136

137
def device_option_equal(opt1, opt2, ignore_node_name=True, ignore_random_seed=True):
    """Return whether two device options describe the same device.

    Args:
        opt1, opt2: ``DeviceOption`` protos (or None).
        ignore_node_name: when False, differing ``node_name`` values make the
            options unequal.
        ignore_random_seed: when False, differing ``random_seed`` values make
            the options unequal.

    Returns:
        If either option is falsy (e.g. None), ``opt1 == opt2``. Otherwise
        True when both are CPU (``device_type`` 0; ``device_id`` is not
        compared for CPU), or when both have the same non-CPU ``device_type``
        and the same ``device_id``.
    """
    if not opt1 or not opt2:
        return opt1 == opt2
    if not ignore_node_name and opt1.node_name != opt2.node_name:
        return False
    if not ignore_random_seed and opt1.random_seed != opt2.random_seed:
        return False
    if opt1.device_type != opt2.device_type:
        # Different device kinds (CPU vs GPU, CUDA vs HIP, ...) never match,
        # whatever their device_id.
        return False
    if not opt1.device_type:
        # Both are CPU: device_id is not compared.
        return True
    return opt1.device_id == opt2.device_id
148

149

150
def InferBlobDevices(net):
    '''
    Compute a mapping from each blob to the device option of the op that
    produces it.

    Args:
        net: a net whose ``Proto().op`` list is scanned in order.

    Returns:
        dict mapping each output blob name to the ``device_option`` of the op
        that writes it. If several ops write the same blob, the last one wins.
    '''
    mapping = {}
    for op in net.Proto().op:
        op_device = op.device_option
        if op_device is None:
            # Protobuf message constructors accept keyword arguments only;
            # the former positional call would raise TypeError.
            op_device = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)
        # TODO: T18892922, use device annotations
        for b in op.output:
            mapping[b] = op_device
    return mapping
164

165

166
def InferOpBlobDevicesAsDict(op):
    """Like InferOpBlobDevices, but keyed by blob name.

    Returns:
        (input_dict, output_dict): dicts mapping each name in ``op.input`` /
        ``op.output`` to the device option inferred for that position.
    """
    input_devices, output_devices = InferOpBlobDevices(op)
    by_input = {
        blob: input_devices[pos] for pos, blob in enumerate(op.input)
    }
    by_output = {
        blob: output_devices[pos] for pos, blob in enumerate(op.output)
    }
    return by_input, by_output
177

178

179
def InferOpBlobDevices(op):
    """Ask the C++ backend which device each input/output of ``op`` uses.

    Returns:
        (input_info, output_info): lists of ``caffe2_pb2.DeviceOption``
        parsed from the two serialized lists returned by
        ``C.infer_op_input_output_device`` (presumably one per op input and
        one per op output, in order).
    """
    device_info = C.infer_op_input_output_device(op.SerializeToString())

    def _parse_all(serialized_options):
        # Turn each serialized DeviceOption back into a proto message.
        parsed = []
        for raw in serialized_options:
            option = caffe2_pb2.DeviceOption()
            option.ParseFromString(raw)
            parsed.append(option)
        return parsed

    return _parse_all(device_info[0]), _parse_all(device_info[1])
192

193

194
def InferOpDeviceAsBlobDevices(op):
    """Assume every input and output of ``op`` lives on the op's device.

    Returns:
        (input_devices, output_devices): lists repeating the op's device
        option (or a fresh ``caffe2_pb2.DeviceOption()`` when it is falsy)
        once per input / output. All entries are the same object.
    """
    if op.device_option:
        op_dev = op.device_option
    else:
        op_dev = caffe2_pb2.DeviceOption()
    return [op_dev] * len(op.input), [op_dev] * len(op.output)
199

200

201
# A sparse gradient, represented by its indices blob and its values blob.
GradientSlice = namedtuple('GradientSlice', 'indices values')
202

203

204
class BlobReference:
    """A wrapper around a blob in a net.

    BlobReference gives us a way to refer to the network that the blob is
    generated from. Note that blobs are, essentially, just strings in the
    current workspace.
    """

    def __init__(self, name, net=None):
        """Initializes a blob reference.

        Note that this does not prepend the namescope. If needed, use
        ScopedBlobReference() to prepend the existing namespace.

        Args:
            name: the blob name. ``bytes`` are decoded as UTF-8; any other
                non-``str`` value is converted with ``str()``.
            net: the net the blob comes from, or None. Required for creating
                operators from this reference (see ``__getattr__``).
        """
        if isinstance(name, str):
            self._name = name
        elif isinstance(name, bytes):
            self._name = name.decode('utf-8')
        else:
            self._name = str(name)
        self._from_net = net
        # meta allows helper functions to put whatever metainformation needed
        # there.
        self.meta = {}

    def __hash__(self):
        # Hash like the bare name, consistent with __eq__ treating a reference
        # as equal to its name string.
        return hash(self._name)

    def __eq__(self, other):
        """Compare by name with a str, UTF-8 bytes, or another BlobReference.

        Any other type compares unequal.
        """
        if isinstance(other, str):
            return self._name == other
        elif isinstance(other, bytes):
            return self._name == other.decode('utf-8')
        elif isinstance(other, BlobReference):
            return self._name == other._name
        else:
            return False

    def __ne__(self, other):
        return not(self == other)

    def __str__(self):
        return self._name

    def __repr__(self):
        return 'BlobReference("{}")'.format(self._name)

    def __add__(self, other):
        """Append a string suffix; the result keeps the same source net."""
        if not isinstance(other, str):
            raise RuntimeError('Cannot add BlobReference to a non-string.')
        return BlobReference(self._name + other, self._from_net)

    def __radd__(self, other):
        """Prepend a string prefix; the result keeps the same source net."""
        if not isinstance(other, str):
            raise RuntimeError('Cannot add a non-string to BlobReference.')
        return BlobReference(other + self._name, self._from_net)

    def Net(self):
        """Return the net this reference was created with (may be None)."""
        return self._from_net

    def GetNameScope(self):
        """Return the name-scope prefix including the trailing separator.

        Returns an empty string when the name contains no separator.
        """
        return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]

    def GetUnscopedName(self):
        """Return the name with its name-scope prefix stripped."""
        return self._name[self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]

    def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
        """Internal function that routes the operator generation to the
        network's __getattr__ function.

        ``self`` is prepended to ``inputs``. A single blob (BlobReference,
        str or bytes) is wrapped in a list; any other iterable (list, tuple,
        ...) is copied, so the caller's sequence is never modified.
        """
        if inputs is None:
            inputs = []
        elif isinstance(inputs, (BlobReference, str, bytes)):
            inputs = [inputs]
        else:
            # Copy: previously the caller's list was mutated in place (and
            # tuples failed because they have no insert()).
            inputs = list(inputs)
        # add self to the input list.
        inputs.insert(0, self)
        return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)

    def __getattr__(self, op_type):
        """A wrapper allowing one to initiate operators from a blob reference.

        Example: for a blob reference b that comes from network n, doing
            b.Relu(...)
        is equivalent to doing
            n.Relu([b], ...)

        Dunder names are rejected up front, so special-method probes raise
        AttributeError instead of being treated as operator names.
        """
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if self._from_net is None:
            raise AttributeError(
                'You cannot use a blob reference that does not have a net '
                'source to create operators. Create the operator from an '
                'explicit net object.')
        if not IsOperator(op_type):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        return lambda *args, **kwargs: self._CreateAndAddToNet(
            op_type, *args, **kwargs)

    def __dir__(self):
        """List regular attributes plus registered operator names.

        Engine-specific registrations are hidden except the CUDNN ones.
        """
        TriggerLazyImport()
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op or '_ENGINE_CUDNN' in op]
        return sorted(set(chain(
            dir(type(self)),
            self.__dict__.keys(),
            additional_methods
        )))
316

317

318
def ScopedName(name):
    """Prefix ``name`` with the current name scope.

    ``bytes`` names are decoded as UTF-8, the same encoding BlobReference
    uses. ASCII is a subset of UTF-8, so every name accepted before is
    decoded identically.
    """
    if isinstance(name, bytes):
        name = name.decode('utf-8')
    return scope.CurrentNameScope() + name
323

324

325
def ScopedBlobReference(name, *args, **kwargs):
    """Return a BlobReference whose name carries the current scope prefix.

    Extra positional and keyword arguments (e.g. ``net``) are forwarded to
    the BlobReference constructor.
    """
    scoped_name = ScopedName(name)
    return BlobReference(scoped_name, *args, **kwargs)
328

329

330
def _RectifyInputOutput(blobs, net=None):
    """Normalize CreateOperator's inputs/outputs argument into a list.

    Accepted forms:
      * a single str/bytes name: scoped with the current name scope;
      * a single BlobReference (exact type): used as-is;
      * a list or tuple (exact types) mixing the two forms above.

    Raises:
        TypeError: for any other argument type, or for an unsupported element
            inside a list/tuple (its position is reported in the message).
    """
    if isinstance(blobs, (bytes, str)):
        # TODO(jiayq): enforce using BlobReference instead of raw strings.
        return [ScopedBlobReference(blobs, net=net)]
    if type(blobs) is BlobReference:
        return [blobs]
    if type(blobs) not in (list, tuple):
        raise TypeError(
            "Unknown input/output type: %s of type %s." %
            (str(blobs), type(blobs))
        )

    rectified = []
    for blob in blobs:
        if isinstance(blob, (bytes, str)):
            rectified.append(ScopedBlobReference(blob, net=net))
        elif type(blob) is BlobReference:
            rectified.append(blob)
        else:
            # len(rectified) is the index of the offending element.
            raise TypeError(
                "I/O blob #{} of unsupported type: {} of type {}"
                .format(len(rectified), str(blob), type(blob)))
    return rectified
360

361

362
def CreateOperator(
    operator_type,
    inputs,
    outputs,
    name='',
    control_input=None,
    device_option=None,
    arg=None,
    engine=None,
    debug_info=None,
    **kwargs
):
    """A function wrapper that allows one to create operators based on the
    operator type. The type should be a string corresponding to an operator
    registered with Caffe2.

    Args:
        operator_type: registered operator type name.
        inputs, outputs: blob name(s) / BlobReference(s) in any form accepted
            by _RectifyInputOutput.
        name: operator name.
        control_input: optional control-input blob(s).
        device_option: explicit DeviceOption; if None, the current device
            scope is used when one is set.
        arg: already-built Argument protos, appended as-is.
        engine: optional engine name.
        debug_info: optional debug string. It overrides the stack trace that
            is recorded when the CAFFE2_DEBUG environment variable is set.
        **kwargs: turned into Argument protos with utils.MakeArgument;
            None-valued entries are skipped. ``random_seed`` is stored on the
            device option instead and is likewise ignored when None.

    Returns:
        The populated ``caffe2_pb2.OperatorDef``. In immediate mode the
        operator is also run right away.
    """
    operator = caffe2_pb2.OperatorDef()
    if (os.environ.get('CAFFE2_DEBUG')):
        stack = traceback.format_stack()
        # Drop the innermost frame (this call) from the recorded trace.
        operator.debug_info = "".join(stack[:-1])

    operator.type = operator_type
    operator.name = name
    # Add rectified inputs and outputs
    inputs = _RectifyInputOutput(inputs)
    outputs = _RectifyInputOutput(outputs)
    operator.input.extend(map(str, inputs))
    operator.output.extend(map(str, outputs))
    if control_input:
        control_input = _RectifyInputOutput(control_input)
        operator.control_input.extend(map(str, control_input))
    # Set device option:
    # (1) If device_option is explicitly set, use device_option.
    # (2) If not, but scope.CurrentDeviceScope() is set,
    #     then we use scope.CurrentDeviceScope().
    # (3) Otherwise, do not set device option.
    if device_option is not None:
        operator.device_option.CopyFrom(device_option)
    elif scope.CurrentDeviceScope() is not None:
        operator.device_option.CopyFrom(scope.CurrentDeviceScope())
    if engine is not None:
        operator.engine = engine
    if debug_info is not None:
        operator.debug_info = debug_info
    # random seed is defined in the device option rather than in the op's
    # arguments. A None seed means "not given", consistent with how
    # None-valued kwargs are skipped below (it used to be written to the
    # proto field, which is not a valid value for it).
    random_seed = kwargs.pop('random_seed', None)
    if random_seed is not None:
        operator.device_option.random_seed = random_seed
    # Add given arguments that do not need parsing
    if arg is not None:
        operator.arg.extend(arg)
    # Add all other arguments
    for key, value in kwargs.items():
        if value is not None:
            operator.arg.add().CopyFrom(utils.MakeArgument(key, value))

    if workspace.IsImmediate():
        workspace.RunOperatorImmediate(operator)
    return operator
423

424

425
def _RegisterPythonImpl(
    f, grad_f=None, python_func_type=None, pass_workspace=False
):
    """Register a Python forward op (and optional gradient) with the backend.

    Args:
        f: forward callable, or a ``(factory, args, kwargs)`` tuple whose
            factory is called to build it. When ``python_func_type`` is
            given, ``f`` is instead passed to that type's constructor.
        grad_f: gradient callable or ``(factory, args, kwargs)`` tuple.
            Ignored when ``python_func_type`` is given.
        python_func_type: optional type; ``python_func_type(f)`` must expose
            ``forward`` and ``backward``, which become f and grad_f.
        pass_workspace: forwarded to ``C.register_python_op``.

    Returns:
        The token returned by ``C.register_python_op``.
    """
    def _materialize(spec):
        # A (factory, args, kwargs) tuple is instantiated; anything else is
        # used as-is.
        if isinstance(spec, tuple):
            return spec[0](*spec[1], **spec[2])
        return spec

    if python_func_type:
        wrapper = python_func_type(f)
        f = wrapper.forward
        grad_f = wrapper.backward
    else:
        f = _materialize(f)
        grad_f = _materialize(grad_f)

    token = C.register_python_op(f, pass_workspace, '')
    if grad_f:
        C.register_python_gradient_op(token, grad_f)
    return token
442

443

444
def CreatePythonOperator(
    f, inputs,
    outputs,
    grad_f=None,
    pass_workspace=False,
    python_func_type=None,
    *args,
    **kwargs
):
    """
    Create a "Python" operator that runs the Python callable `f`.

    `f` should have a signature (inputs, outputs)

    If `pass_workspace` is True, the signature is changed to
    (inputs, outputs, workspace) where `workspace` is the workspace the op
    is going to run on. This is potentially dangerous (as the op can manipulate
    the workspace directly), use at your own risk.

    `f`, `grad_f` and `python_func_type` are registered via
    _RegisterPythonImpl. The resulting token is passed to the operator as its
    `token` argument, and the remaining args/kwargs go to CreateOperator.
    """
    registration_token = _RegisterPythonImpl(
        f, grad_f, python_func_type, pass_workspace=pass_workspace
    )
    kwargs["token"] = registration_token
    return CreateOperator("Python", inputs, outputs, *args, **kwargs)
465

466

467
def GetIndexFromGradientList(g_list, name):
    """Return the position in ``g_list`` of the gradient that refers to ``name``.

    An entry matches when it equals ``name`` or, for a GradientSlice, when
    its ``indices`` or ``values`` equals ``name``. Returns None when nothing
    matches.
    """
    for position, grad in enumerate(g_list):
        if grad == name:
            return position
        if type(grad) is GradientSlice and (
                grad.indices == name or grad.values == name):
            return position
    return None
477

478

479
# One replayed forward op with the versions it read (in_versions) and wrote
# (out_versions).
OpSSA = namedtuple('OpSSA', 'op in_versions out_versions')
# Records that output `idx` of `grad_op` produces the dense `gradient`;
# grad_op is None when the gradient is passed through without a grad op.
GradGenMeta = namedtuple(
    'GradGenMeta', 'grad_op idx gradient device_option')
# Sparse counterpart of GradGenMeta: the indices and values halves of the
# gradient can come from different grad ops / outputs.
SparseGradGenMeta = namedtuple(
    'SparseGradGenMeta',
    'grad_op_indices idx_indices '
    'grad_op_values idx_values '
    'gradient device_option',
)
487

488

489
class IR:
490
    """A simple IR class to keep track of all intermediate representations used
491
    in the gradient computation.
492
    """
493

494
    def __init__(self, operators):
495
        # The IR class holds multiple metadata from the forward pass:
496
        # a) ssa: a list of [op, in_versions, out_versions] recording the
497
        #    input and the output version of each operator, similar
498
        #    to a normal SSA form.
499
        # b) input_usages: a dictionary specifying for each blob and
500
        #    each of its version, how many times it is used as input for another
501
        #    op.
502
        # c) frontier: maintaining the current versions of the blobs
503
        #    we are having in the workspace, after the execution of all the ops
504
        #    added to the IR so far. This is useful because if a gradient is
505
        #    trying to access an earlier version of a blob, we can sanity check
506
        #    that it is no longer there, and thus throw an error.
507
        # d) gradient_frontier: maps the names of blobs to its version that the
508
        #    gradient corresponds to.
509
        # e) gradient_generators: for each blob and each of its version, maps to
510
        #    a list of operators that generates its gradient together with the
511
        #    gradient name.
512
        self.ssa = []
513
        self.input_usages = defaultdict(lambda: defaultdict(list))
514
        self.frontier = defaultdict(int)
515
        self.gradient_frontier = {}
516
        self.gradient_generators = defaultdict(lambda: defaultdict(list))
517
        self.out_version_history = defaultdict(list)
518
        self.in_version_history = defaultdict(list)
519

520
        for op in operators:
521
            self.Play(op)
522

523
        self.SanityCheck(operators)
524

525
    def SanityCheck(self, operators):
526
        # Validate StopGradient usage by checking that StopGradient's output
527
        # is actually passed forward
528
        for op in operators:
529
            if op.type == 'StopGradient':
530
                if op.output[0] not in self.input_usages:
531
                    raise ValueError("""StopGradient's output '{}' is orphan.
532
You typically want to specify same input and output for
533
StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
534

535
    def Play(self, op):
        """Add an op to the current IR and update the internal state to
        reflect the blobs and versions after the op executes.

        Inputs are recorded at their current frontier version (0 for a blob
        not seen before; reading it also registers it in the frontier).
        Outputs get a new version: one past the frontier for a known blob,
        0 for a brand-new one. A blob that was only read so far therefore
        gets version 1 on its first write.
        """
        # For input, they are the current version in the dict.
        in_versions = {}
        for s in op.input:
            in_versions[s] = self.frontier[s]
            self.input_usages[s][self.frontier[s]].append(len(self.ssa))
            self.in_version_history[s].append((op, self.frontier[s]))
        # For output, they are the current version plus one. If this is a
        # newly created blob, its version starts with zero.
        out_versions = {}
        for s in op.output:
            if s in self.frontier:
                self.frontier[s] += 1
            out_versions[s] = self.frontier[s]
            self.out_version_history[s].append((op, self.frontier[s]))
        # Add to SSA for bookkeeping.
        self.ssa.append(OpSSA(op, in_versions, out_versions))
555

556
    def CheckGradientOperatorInput(
            self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs):
        """Checks if the gradient operators can be correctly carried out.

        Validates ``grad_op_input``, one input blob of a gradient operator
        generated for the forward op at ``self.ssa[fwd_op_idx]``. Checked in
        this order, the blob must be:

        1. a gradient listed in ``g_output``: the matching forward output's
           version must equal the one in ``self.gradient_frontier``;
        2. an output of the forward op: the frontier must still be at the
           version the op produced;
        3. an input of the forward op: the frontier must still be at the
           version the op read;
        4. otherwise, a member of ``locally_generated_blobs``.

        Raises:
            RuntimeError: when a check fails. Version mismatches append the
                blob's recorded version history as debug help.
        """
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        original_index = GetIndexFromGradientList(g_output, grad_op_input)

        def debug_help(name, hint, history, label):
            # Lists every (op, version) pair recorded for `name`.
            pieces = [
                "DEBUG HELP:\n",
                hint,
                "== Version history of blob [{}]\n".format(name),
            ]
            for op, vers in history[name]:
                pieces.append("{} {} <-- {}\n".format(label, vers, op))
            return "".join(pieces)

        def out_help(name):
            return debug_help(
                name,
                "Maybe you use same output blob twice for different ops?\n",
                self.out_version_history,
                "Version (out)",
            )

        def in_help(name):
            return debug_help(
                name,
                "Maybe the blob was overwritten by another op?\n",
                self.in_version_history,
                "version (in)",
            )

        # Case 1: a dense or sparse gradient name.
        if original_index is not None:
            original_name = forward_op.output[original_index]
            expected_version = out_versions[original_name]
            gradient_version = self.gradient_frontier[original_name]
            if expected_version != gradient_version:
                raise RuntimeError(
                    'Gradient name "%s" is expected to correspond '
                    'to version %d of "%s", but currently we have '
                    'version %d.\n\n' % (
                        grad_op_input, expected_version,
                        original_name, gradient_version) +
                    out_help(original_name))
            return

        # Case 2: an output of the forward op.
        if grad_op_input in out_versions:
            current_version = self.frontier[grad_op_input]
            if current_version != out_versions[grad_op_input]:
                raise RuntimeError(
                    'Gradient operator needs output "%s" at version'
                    ' %d, but currently we have version %d.\n\n' % (
                        grad_op_input, out_versions[grad_op_input],
                        current_version
                    ) + out_help(grad_op_input)
                )
            return

        # Case 3: an input of the forward op.
        if grad_op_input in in_versions:
            current_version = self.frontier[grad_op_input]
            if current_version != in_versions[grad_op_input]:
                raise RuntimeError(
                    'Gradient operator needs input "%s" at version '
                    '%d, but currently we have version %d.\n\n' % (
                        grad_op_input, in_versions[grad_op_input],
                        current_version
                    ) + in_help(grad_op_input)
                )
            return

        # Case 4: must come from an earlier local gradient operator.
        if grad_op_input not in locally_generated_blobs:
            raise RuntimeError(
                'Blob name "%s" not in the scope of operator: '
                '%s\nand is not generated by any of the local '
                'gradient operators.' % (grad_op_input, str(forward_op))
            )
626

627
    def AppendSparseGenerators(self, sparse_generators):
628
        # merge indices and values generators for sparse gradients
629
        for name, input_generators in sparse_generators.items():
630
            for version, generators in input_generators.items():
631
                if len(generators) == 1:
632
                    # either indices or values are generated (but not both)
633
                    generator = generators[0]
634
                else:
635
                    # both indices and values are generated
636
                    assert(len(generators) == 2)
637
                    op1_i, idx1_i, op1_v, idx1_v, g1, dev_1 = generators[0]
638
                    op2_i, idx2_i, op2_v, idx2_v, g2, dev_2 = generators[1]
639
                    assert(g1 == g2)
640
                    assert dev_1 == dev_2, (
641
                        "Unequal devices for sparse generators: "
642
                        "{} and {}".format(dev_1, dev_2)
643
                    )
644
                    assert(op1_i is None or op2_i is None)
645
                    assert(op1_v is None or op2_v is None)
646
                    assert(idx1_i == 0 or idx2_i == 0)
647
                    assert(idx1_v == 0 or idx2_v == 0)
648
                    generator = SparseGradGenMeta(
649
                        op1_i or op2_i, idx1_i + idx2_i,
650
                        op1_v or op2_v, idx1_v + idx2_v,
651
                        g1, dev_1)
652
                self.gradient_generators[name][version].append(generator)
653

654
    def BuildGradientGenerators(  # NOQA
            self, fwd_op_idx, gradient_ops, g_output, g_input):
        """Updates gradient_generators and gradient_frontier

        For the forward op at ``self.ssa[fwd_op_idx]``:
          (1) each input of each op in ``gradient_ops`` is validated with
              CheckGradientOperatorInput;
          (2) gradient-op outputs that appear in ``g_input`` are recorded as
              generators of the matching forward input (dense directly, the
              indices/values halves of sparse ones as partial records);
          (3) the sparse partial records are merged via
              AppendSparseGenerators;
          (4) gradients in ``g_input`` not produced by any of these ops (e.g.
              passed straight through by Add/Sum/Sub) get a generator with no
              grad op, so later gradient accumulation knows their source;
          (5) gradient_frontier is set, for each non-None ``g_input`` entry,
              to the version of the forward input it is the gradient of.
        """
        forward_op, in_versions, _ = self.ssa[fwd_op_idx]
        produced_here = []
        sparse_partials = defaultdict(lambda: defaultdict(list))

        for grad_op in gradient_ops:
            # (1) every input must be available at the right version.
            for blob in grad_op.input:
                self.CheckGradientOperatorInput(
                    blob, g_output, fwd_op_idx, produced_here)

            # (2) outputs become visible to later gradient ops of this
            # forward op; gradients of forward inputs are recorded.
            produced_here.extend(map(str, grad_op.output))
            for out_pos, out_blob in enumerate(grad_op.output):
                in_pos = GetIndexFromGradientList(g_input, out_blob)
                if in_pos is None:
                    continue
                fwd_input = forward_op.input[in_pos]
                fwd_version = in_versions[fwd_input]
                grad = g_input[in_pos]
                if type(grad) is not GradientSlice:
                    self.gradient_generators[fwd_input][fwd_version] \
                        .append(GradGenMeta(
                            grad_op, out_pos, grad, grad_op.device_option))
                    continue
                # Sparse: this output is the indices or the values half of
                # the gradient; keep a partial record, merged in step (3).
                if grad.indices == out_blob:
                    partial = SparseGradGenMeta(
                        grad_op, out_pos, None, 0, grad,
                        grad_op.device_option)
                else:
                    assert(grad.values == out_blob)
                    partial = SparseGradGenMeta(
                        None, 0, grad_op, out_pos, grad,
                        grad_op.device_option)
                sparse_partials[fwd_input][fwd_version].append(partial)

        # (3) merge the sparse halves and record them.
        self.AppendSparseGenerators(sparse_partials)

        # (4) pass-through gradients get a generator with no grad op; this is
        # needed to create the Sum op that accumulates gradients from
        # several parents.
        for in_pos, grad in enumerate(g_input):
            fwd_input = forward_op.input[in_pos]
            fwd_version = in_versions[fwd_input]
            if not grad:
                continue
            if type(grad) is GradientSlice:
                untouched = (
                    str(grad.indices) not in produced_here and
                    str(grad.values) not in produced_here)
                if untouched:
                    self.gradient_generators[fwd_input][fwd_version].append(
                        SparseGradGenMeta(
                            None, 0, None, 0, grad,
                            forward_op.device_option))
            elif str(grad) not in produced_here:
                self.gradient_generators[fwd_input][fwd_version].append(
                    GradGenMeta(None, 0, grad, forward_op.device_option))

        # (5) advance the gradient frontier to the input versions the
        # gradients correspond to.
        for in_pos, grad in enumerate(g_input):
            if grad is None:
                continue
            fwd_input = forward_op.input[in_pos]
            self.gradient_frontier[fwd_input] = in_versions[fwd_input]
728

729
    def _GetSumOpOutputName(self, generator, input_name):
730
        def remove_suffix(s, suffix):
731
            if s.endswith(suffix):
732
                return s[:-len(suffix)]
733
            return s
734

735
        for g in generator:
736
            if type(g) is GradGenMeta:
737
                grad_op, idx, _, _ = g
738
                if grad_op:
739
                    return grad_op.output[idx]
740
            else:
741
                assert(type(g) is SparseGradGenMeta)
742
                op_i, idx_i, op_v, idx_v, _, _ = g
743
                if op_i:
744
                    return remove_suffix(op_i.output[idx_i], '_indices')
745
                if op_v:
746
                    return remove_suffix(op_v.output[idx_v], '_values')
747

748
        return input_name + '_grad'
749

750
    IS_AUTO_GEN_SUM_OPS_TAG = "is_auto_gen_sum_ops"
    ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG = "only_keep_is_auto_gen_sum_ops_tag"

    def _SetSumOpsDeviceOption(self, sum_ops, generators):
        """Give the auto-generated ``sum_ops`` a device option and tag them.

        Every sum op gets ``"<IS_AUTO_GEN_SUM_OPS_TAG>:1"`` appended to its
        ``device_option.extra_info``. Unless some generator's device option
        carries ``"<ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG>:1"`` in its extra_info,
        the device option of the first generator that has one is copied onto
        each sum op before tagging. The generators' device options were
        already checked to be consistent, so the first one is representative.

        Generators whose ``device_option`` is None are skipped rather than
        raising AttributeError/TypeError.

        Args:
            sum_ops: OperatorDefs created to aggregate gradients; modified in
                place.
            generators: GradGenMeta / SparseGradGenMeta entries that produced
                the gradients being aggregated.
        """
        auto_gen_tag = "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG)
        only_keep_tag = "{}:1".format(IR.ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG)

        device_options = [
            generator.device_option for generator in generators
            if generator.device_option is not None
        ]
        # any() stops at the first device option requesting that only the
        # auto-gen tag be kept, instead of scanning all remaining generators.
        only_keep_tag_requested = any(
            only_keep_tag in device_option.extra_info
            for device_option in device_options)

        for op in sum_ops:
            if not only_keep_tag_requested and device_options:
                op.device_option.CopyFrom(device_options[0])
            op.device_option.extra_info.extend([auto_gen_tag])

    def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
784
        new_grad_output = (
785
            '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
786
        if grad_op.type == "If":
787
            disambiguate_grad_if_op_output(grad_op, idx, new_grad_output)
788
        else:
789
            grad_op.output[idx] = new_grad_output
790
        return grad_op.output[idx], cnt + 1
791

792
    def _CheckSumOpsConflict(self, out_base_name, g):
793
        if str(out_base_name) == str(g):
794
            # TODO not sure what this message really means
795
            raise RuntimeError(
796
                'The gradient output of empty gradient op can not '
797
                'be the same as the normal name of the current '
798
                'input gradient.')
799

800
    def _MakeDenseSumOps(self, generators, out_base_name):
        """Aggregate dense gradients with a single ``Sum`` operator.

        The first generator backed by a gradient op keeps its output name.
        Outputs of later gradient ops are renamed with
        ``_DisambiguateGradOpOutput`` so they do not overwrite each other.
        Gradients that arrive without a gradient op are used by name, after
        checking that they do not collide with ``out_base_name``.

        Returns:
            Tuple of ``[sum_op]`` and ``out_base_name`` (the sum's output).
        """
        assert len(generators) > 1

        summands = []
        split_count = 0
        kept_first_output = False
        for grad_op, out_idx, grad, _ in generators:
            assert type(grad) is not GradientSlice
            if not grad_op:
                self._CheckSumOpsConflict(out_base_name, grad)
                summands.append(str(grad))
                continue
            if kept_first_output:
                blob, split_count = self._DisambiguateGradOpOutput(
                    grad_op, out_idx, split_count)
            else:
                kept_first_output = True
                blob = grad_op.output[out_idx]
            summands.append(blob)

        if out_base_name in summands:
            # Sum inplace mode works only for the first input, so move the
            # output blob to the front.
            pos = summands.index(out_base_name)
            summands[0], summands[pos] = summands[pos], summands[0]

        sum_op = CreateOperator(
            "Sum",
            [BlobReference(name) for name in summands],
            BlobReference(out_base_name))
        return [sum_op], out_base_name

    def _MakeSparseSumOps(self, generators, out_base_name):
        """Aggregate sparse gradients by concatenating indices and values.

        Every gradient-op output is renamed with ``_DisambiguateGradOpOutput``
        (separate counters for indices and values). Gradients that arrive
        without a gradient op are used by name, after a collision check
        against ``out_base_name``.

        Returns:
            Tuple of the two ``Concat`` ops (indices first, then values) and a
            GradientSlice naming their concatenated outputs.
        """
        indices_parts = []
        values_parts = []
        indices_count = 0
        values_count = 0

        for meta in generators:
            assert type(meta) is SparseGradGenMeta
            indices_op, indices_idx, values_op, values_idx, grad, _ = meta

            if not indices_op:
                self._CheckSumOpsConflict(out_base_name, grad.indices)
                indices_parts.append(grad.indices)
            else:
                blob, indices_count = self._DisambiguateGradOpOutput(
                    indices_op, indices_idx, indices_count)
                indices_parts.append(blob)

            if not values_op:
                self._CheckSumOpsConflict(out_base_name, grad.values)
                values_parts.append(grad.values)
            else:
                blob, values_count = self._DisambiguateGradOpOutput(
                    values_op, values_idx, values_count)
                values_parts.append(blob)

        indices_concat_output = out_base_name + '_indices_concat'
        values_concat_output = out_base_name + '_values_concat'

        def concat(parts, output):
            # Concat along axis 0; the second output is named '<output>_split'.
            return CreateOperator(
                "Concat",
                [BlobReference(part) for part in parts],
                [BlobReference(output), BlobReference(output + '_split')],
                axis=0
            )

        # "Summing" sparse gradients just concatenates their indices (resp.
        # values) tensors. Indices are not deduplicated here; that is done as
        # needed before the optimizer is called.
        sum_ops = [
            concat(indices_parts, indices_concat_output),
            concat(values_parts, values_concat_output),
        ]
        return sum_ops, GradientSlice(
            indices=indices_concat_output,
            values=values_concat_output,
        )

    def _MakeSumOps(self, input_name, input_version):
        """Create the ops aggregating all gradients of one input version.

        Dense generators are summed by ``_MakeDenseSumOps``; sparse ones are
        concatenated by ``_MakeSparseSumOps``. The generators must all be of
        one type. The resulting ops then get their device option from
        ``_SetSumOpsDeviceOption``.

        Returns:
            Tuple of the aggregation ops and the aggregated gradient.
        """
        generators = self.gradient_generators[input_name][input_version]
        out_base_name = self._GetSumOpOutputName(generators, input_name)
        kinds = list({type(meta) for meta in generators})
        assert len(kinds) == 1
        if kinds[0] is GradGenMeta:
            make_ops = self._MakeDenseSumOps
        else:
            assert kinds[0] is SparseGradGenMeta
            make_ops = self._MakeSparseSumOps
        sum_ops, g = make_ops(generators, out_base_name)
        self._SetSumOpsDeviceOption(sum_ops, generators)
        return sum_ops, g

    def _VerifyGradientGenerators(self, generator):
900
        # (1) check if all gradients are of the same type. Aggregating a mix of
901
        # sparse and dense gradients is not supported yet
902
        if len({type(g) for g in generator}) > 1:
903
            raise RuntimeError(
904
                'Automatic aggregation of a mix of sparse and dense gradients '
905
                'is not supported yet')
906

907
        # If for all the operators that used the operator, none or only one
908
        # produced the gradient, then no additional sum needs to be carried
909
        # out.
910
        if len(generator) < 2:
911
            return False
912

913
        all_gradient_names = []
914
        all_device_options = []
915
        for g in generator:
916
            if g.device_option:
917
                all_device_options.append(g.device_option)
918
            if type(g) is GradGenMeta:
919
                if g.grad_op:
920
                    all_gradient_names.append(g.gradient)
921
            else:
922
                assert(type(g) is SparseGradGenMeta)
923
                if g.gradient.values:
924
                    all_gradient_names.append(g.gradient.values)
925

926
        # Check if all grad op device options are the same.
927
        if len(all_device_options) >= 2 and not all(
928
                device_option_equal(d, all_device_options[0])
929
                for d in all_device_options[1:]):
930
            raise RuntimeError('Unexpected behavior: not all grad ops '
931
                               'have the same device option.')
932
        return True
933

934
    def DoGradientAccumulation(self, fwd_op_idx):
        """For each input name in the forward op, check if we will need to
        add gradient accumulation. If so, do gradient accumulation and return
        the list of gradient operators.

        The criteria for doing gradient accumulation is:
        (1) the specific input version has been used by multiple operators.
        (2) the current fwd_op_idx is the first to use that input, i.e. in the
            backward pass, is the last to optionally generate the gradient for
            the op.
        (3) For the operators that used the input, their gradient operators
            have generated more than 1 gradient.

        When accumulating operators, our current solution is to rename all the
        created gradients with an internal intermediate name, and then add a
        Sum() operator that adds up all the gradients. This may use more memory
        due to intermediate storage, but is usually the fastest approach as one
        can do one single sum for multiple intermediate gradients.

        Returns:
            Tuple (additional_sum_ops, grad_map): the aggregation ops to
            append, and a map from input name to its aggregated gradient (a
            blob name for dense gradients, a GradientSlice for sparse ones).

        Raises:
            RuntimeError: if the gradient generators of an input fail
                verification (see ``_VerifyGradientGenerators``).
        """
        forward_op, in_versions, _ = self.ssa[fwd_op_idx]
        additional_sum_ops = []
        grad_map = {}
        # De-duplicate inputs in order of first appearance. Iterating a set of
        # strings would follow hash order, which changes between interpreter
        # runs (PYTHONHASHSEED) and made the emitted op order nondeterministic.
        for input_name in dict.fromkeys(forward_op.input):
            input_version = in_versions[input_name]
            input_usage = self.input_usages[input_name][input_version]
            if len(input_usage) <= 1 or fwd_op_idx != input_usage[0]:
                # We do not need to do gradient accumulation yet.
                continue
            generator = self.gradient_generators[input_name][input_version]
            try:
                if not self._VerifyGradientGenerators(generator):
                    continue
            except RuntimeError as err:
                raise RuntimeError(
                    "Gradients for param ''{}'' failed to verify: {}".format(
                        input_name,
                        err
                    )
                ) from err

            # Finally, let's create the sum operator.
            sum_ops, g = self._MakeSumOps(input_name, input_version)
            additional_sum_ops.extend(sum_ops)
            grad_map[input_name] = g
        return additional_sum_ops, grad_map

    def _AppendAutoGradGenerator(self, y, grad, autograd_op):
        """Record ``autograd_op`` as the generator of the gradient of ``y``.

        A dense GradGenMeta is appended under the current frontier version of
        ``str(y)``. When ``autograd_op`` is None, the generator gets neither an
        output index nor a device option.
        """
        # Gradient here is not sparse as it was generated by
        # a ConstantFill operator. Autogeneration for sparse gradients is
        # not supported
        has_op = autograd_op is not None
        generator = GradGenMeta(
            autograd_op,
            0 if has_op else None,
            str(grad),
            # Guarded like the index above so a None op does not raise
            # AttributeError.
            autograd_op.device_option if has_op else None,
        )

        self.gradient_generators[str(y)][self.frontier[str(y)]].append(
            generator)

    AUTOGEN_GRAD_SUFFIX = "_autogen_grad"

    def _GetInitGradients(self, ys):
        """Build the initial gradient map for the blobs in ``ys``.

        For every ``y`` whose gradient is None, a ``ConstantFill`` op with
        value 1.0 and output ``str(y) + AUTOGEN_GRAD_SUFFIX`` is created, used
        as the gradient, and registered as a gradient generator of ``y``.

        Returns:
            Tuple (input_to_grad, gradient_ops). ``input_to_grad`` maps
            ``str(y)`` to its gradient as a string, or as a GradientSlice of
            strings. ``gradient_ops`` lists the ConstantFill ops created.
        """
        input_to_grad = {}
        gradient_ops = []

        for y, grad in ys.items():
            fill_op = None
            if grad is None:
                fill_op = CreateOperator(
                    "ConstantFill", [y], [str(y) + IR.AUTOGEN_GRAD_SUFFIX],
                    value=1.0)
                gradient_ops.append(fill_op)
                grad = fill_op.output[0]
            # The C++ gradient registry has no notion of NameScopes, so every
            # reference is stored as a plain string.
            if isinstance(grad, GradientSlice):
                input_to_grad[str(y)] = GradientSlice(
                    str(grad[0]), str(grad[1]))
            else:
                input_to_grad[str(y)] = str(grad)
            # Autogenerated gradients are assumed to be provided for the last
            # input version
            if fill_op is not None:
                self._AppendAutoGradGenerator(y, grad, fill_op)

        return input_to_grad, gradient_ops

    def _GenerateGradientsForForwardOp(
            self, forward_op_idx, input_to_grad):
        """Create the gradient ops for the forward op at ``forward_op_idx``.

        Gradient ops are generated only when at least one output of the op
        already has a gradient in ``input_to_grad``, or when the op is a
        ``ZeroGradient`` op.

        Returns:
            Tuple (new_input_to_grad, gradient_ops): the gradients found for
            the op's inputs, and the gradient operators. Both are empty when
            nothing was generated.
        """
        forward_op, _in_versions, _out_versions = self.ssa[forward_op_idx]
        g_output = [input_to_grad.get(name, None) for name in forward_op.output]

        if (all(g is None for g in g_output)
                and forward_op.type != "ZeroGradient"):
            return {}, []

        gradient_ops, g_input = GradientRegistry.GetGradientForOp(
            forward_op, g_output)
        # Check if the gradient operators are legal, and update
        # gradient_generators and gradient_frontier
        self.BuildGradientGenerators(
            forward_op_idx, gradient_ops, g_output, g_input)

        output_names = set(forward_op.output)
        new_input_to_grad = {}
        for name, grad in zip(forward_op.input, g_input):
            # Do not overwrite an existing gradient with a None unless the
            # input is also an output of the op, since we update the blob
            # version when a blob is an output of an operator.
            if (grad is not None or name not in input_to_grad
                    or name in output_names):
                new_input_to_grad[name] = grad

        return new_input_to_grad, gradient_ops

    def GetBackwardPass(self, ys):
        """Gets the backward pass that computes the derivatives of given blobs.

        Inputs:
          ys: a list or a dictionary specifying what blobs we want to compute
              derivatives of. If the input is a list, we will automatically
              generate their gradients with all-one values; if the input is a
              dictionary, for any dictionary entries that are not None, we will
              take the corresponding blobs as their gradients; for all those
              that are None, we will auto-fill them with 1.

        Returns:
          Tuple (gradient_ops, input_to_grad): the list of gradient operators,
          and a map from BlobReference of each blob to its gradient (a
          BlobReference, or a GradientSlice of BlobReferences). Blobs whose
          gradient is None are omitted.
        """
        if isinstance(ys, list):
            ys = dict.fromkeys(ys)
        elif not isinstance(ys, dict):
            raise TypeError("ys should either be a list or a dict.")

        # (1) Seed the gradient frontier with the requested blobs, and record a
        # use of each one at index len(self.ssa), i.e. after the last forward op.
        for y in ys:
            self.gradient_frontier[y] = self.frontier[y]
            self.input_usages[str(y)][self.frontier[str(y)]].append(
                len(self.ssa))

        all_input_to_grad, all_gradient_ops = self._GetInitGradients(ys)

        # (2) Play the ops backwards, creating the gradients along the path.
        # Although we play backwards, we cannot refer to variables at a version
        # older than the current one, because they have been overwritten.
        for forward_op_idx in range(len(self.ssa) - 1, -1, -1):
            input_to_grad, gradient_ops = self._GenerateGradientsForForwardOp(
                forward_op_idx, all_input_to_grad)
            all_input_to_grad.update(input_to_grad)
            all_gradient_ops.extend(gradient_ops)

            # Blobs used by several ops get their gradients accumulated here.
            # Merging grad_map separately keeps ops that produced no gradient
            # from overwriting the general all_input_to_grad map.
            additional_sum_ops, grad_map = self.DoGradientAccumulation(
                forward_op_idx)
            all_input_to_grad.update(grad_map)
            all_gradient_ops.extend(additional_sum_ops)

        # (3) Post-processing: convert everything to BlobReferences for easier
        # handling in python.
        input_to_grad_refs = {}
        for blob, grad in all_input_to_grad.items():
            if grad is None:
                continue
            if isinstance(grad, (bytes, str)):
                grad_ref = BlobReference(grad)
            else:
                grad_ref = GradientSlice(BlobReference(grad[0]),
                                         BlobReference(grad[1]))
            input_to_grad_refs[BlobReference(blob)] = grad_ref
        return all_gradient_ops, input_to_grad_refs


class GradientRegistry:
    """GradientRegistry holds the mapping from operators to their gradients.

    Gradients are requested from the C++ side first (``C.get_gradient_defs``).
    If that raises, a Python gradient function registered for the op type via
    ``RegisterGradient`` is used instead.
    """
    gradient_registry_ = {}

    @classmethod
    def RegisterGradient(cls, op_type):
        """A decorator for registering gradient mappings.

        Usage::

            @GradientRegistry.RegisterGradient("MyOp")
            def MyOpGradient(op, g_output):
                ...
                return gradient_ops, g_input

        The decorated function is stored for ``op_type`` (replacing any earlier
        registration for that type) and returned unchanged.
        """

        def Wrapper(func):
            cls.gradient_registry_[op_type] = func
            return func

        return Wrapper

    @classmethod
    def _GetGradientForOpCC(cls, op_def, g_output):
        """Ask the C++ gradient registry for the gradient of ``op_def``.

        Args:
            op_def: the forward OperatorDef.
            g_output: one entry per output of ``op_def``. Each entry is None,
                a dense gradient blob name, or an (indices, values) pair.

        Returns:
            Tuple (grad_defs, g_input): the gradient OperatorDefs, and one
            entry per input of ``op_def``. Each input entry is None, a dense
            blob name, or a GradientSlice.
        """
        # TODO(tulloch) - Propagate GradientWrapper up through the stack.
        def from_untyped(grad):
            w = C.GradientWrapper()
            if grad is None:
                assert w.is_empty()
                return w
            # Blob names are strings, and a two-character string unpacks into
            # two items just like an (indices, values) pair does ("dx" ->
            # ("d", "x")). Only non-strings may be interpreted as sparse.
            if not isinstance(grad, (str, bytes)):
                try:
                    indices, values = grad
                except ValueError:
                    pass  # not a pair: treat it as a dense gradient below
                else:
                    w.indices = indices
                    w.values = values
                    assert w.is_sparse()
                    return w
            w.dense = grad
            assert w.is_dense()
            return w

        g_output = [from_untyped(grad) for grad in g_output]
        grad_defs_str, g_input = C.get_gradient_defs(
            op_def.SerializeToString(), g_output)

        def to_untyped(grad_wrapper):
            if grad_wrapper.is_empty():
                return None
            if grad_wrapper.is_sparse():
                return GradientSlice(grad_wrapper.indices, grad_wrapper.values)
            assert grad_wrapper.is_dense()
            return grad_wrapper.dense

        g_input = [to_untyped(grad_wrapper) for grad_wrapper in g_input]
        grad_defs = []
        for grad_def_str in grad_defs_str:
            grad_def = caffe2_pb2.OperatorDef()
            grad_def.ParseFromString(grad_def_str)
            grad_defs.append(grad_def)
        return grad_defs, g_input

    @classmethod
    def GetGradientForOp(cls, op, g_output):
        """Return ``(gradient_ops, g_input)`` for the forward operator ``op``.

        The C++ registry is tried first. If it raises, the Python function
        registered for ``op.type`` is used. ``gradient_ops`` is always
        returned as a list: None becomes ``[]``, and a single non-list result
        is wrapped.

        Raises:
            Exception: if the C++ lookup fails and no Python gradient is
                registered for ``op.type`` (chained to the C++ error).
        """
        try:
            gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output)
        except Exception as e:
            # Not supported in C++; will try python registration next.
            if op.type in cls.gradient_registry_:
                gradient_ops, g_input = cls.gradient_registry_[op.type](
                    op, g_output
                )
            else:
                raise Exception(
                    "Exception when creating gradient for [{}]:{}.\nOp: \n{}".
                    format(op.type, e, str(op))
                ) from e

        if gradient_ops is None:
            return [], g_input
        if type(gradient_ops) is not list:
            gradient_ops = [gradient_ops]
        return gradient_ops, g_input

    @classmethod
    def GetBackwardPass(cls, operators, ys, ys_generate_gradient=False):
        """Gets the backward pass for the list of operators.

        Args:
            operators: a list of operators constituting the forward pass.
            ys: a list or a dictionary specifying what blobs we want to compute
                derivatives of. If the input is a list, we will automatically
                generate their gradients with all-one values; if the input is a
                dictionary, for any dictionary entries that are not None, we'll
                take the corresponding blobs as their gradients; for all those
                that are None, we will auto-fill them with 1.
            ys_generate_gradient: not used by this implementation.
        Returns:
            gradient_ops: a list of gradient operators to run.
            all_input_to_grads: a map from input to their corresponding
                gradients.
        """
        ir = IR(operators)
        return ir.GetBackwardPass(ys)


# Gradients of the control-flow ops Do / If / While are implemented in
# caffe2.python.control_ops_grad and registered in the Python registry.
for _op_type, _grad_fn in (
    ('Do', gen_do_gradient),
    ('If', gen_if_gradient),
    ('While', gen_while_gradient),
):
    GradientRegistry.RegisterGradient(_op_type)(_grad_fn)
del _op_type, _grad_fn


def get_ssa(net, blob_versions=None):
    """
    Given a net, return a structure containing the version of each input and
    output blob used by each operator.

    Args:
        net:            a Net, a NetDef, or a list of those. List elements are
                        processed in order and share (and update) the same
                        `blob_versions` map.
        blob_versions:  (optional) map with current version number for given
                        blob names. If not provided or blob not found, start
                        from version 0.
    Returns:
        Tuple (ssa, blob_versions)
        ssa:            list of tuples (versioned_inputs, versioned_outputs)
                        for each op in the net. A versioned input is a tuple
                        (blob_name, version). When `net` is a list, this is
                        instead a list holding one get_ssa() result tuple per
                        element.
        blob_versions:  updated map with latest version of each blob found in
                        the net.
    """
    if blob_versions is None:
        blob_versions = {}
    # Lists must be handled before the NetDef check below, which would
    # otherwise reject them and leave this branch unreachable.
    if isinstance(net, list):
        return [get_ssa(n, blob_versions) for n in net], blob_versions
    proto = net.Proto() if isinstance(net, Net) else net
    assert isinstance(proto, caffe2_pb2.NetDef)
    for i in proto.external_input:
        if i not in blob_versions:
            blob_versions[str(i)] = 0
    ssa = []
    for op in proto.op:
        # Without declared external inputs, every op input not seen so far is
        # registered at version 0.
        if not proto.external_input:
            for i in op.input:
                if i not in blob_versions:
                    blob_versions[i] = 0
        inputs = [(str(i), blob_versions.get(str(i), 0)) for i in op.input]
        for o in op.output:
            blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1
        outputs = [(str(o), blob_versions[str(o)]) for o in op.output]
        ssa.append((inputs, outputs))
    return ssa, blob_versions


def get_undefined_blobs(ssa):
    """
    Return the set of blob names that some operator reads at version 0, i.e.
    before any operator in the net has written them.

    `ssa` is in the format produced by get_ssa().
    """
    return {
        name
        for versioned_inputs, _ in ssa
        for name, version in versioned_inputs
        if version == 0
    }


def get_output_producers(ssa):
    """
    Map each versioned blob (blob_name, version) to the index of the operator
    that produces it. `ssa` is in the format produced by get_ssa(); if several
    ops list the same versioned blob, the last one wins.
    """
    return {
        versioned_blob: op_index
        for op_index, (_, versioned_outputs) in enumerate(ssa)
        for versioned_blob in versioned_outputs
    }


def get_op_ids_in_path(ssa, blob_versions, inputs, outputs):
    """
    Return the sorted indices of the ops needed to compute `outputs` from
    `inputs`.

    `ssa` and `blob_versions` are as produced by get_ssa(). Both `inputs` and
    `outputs` are taken at their latest version in `blob_versions`. The search
    walks backwards from the outputs through producing ops and stops at any
    given input or at blobs no op produces.
    """
    given = {(str(blob), blob_versions[str(blob)]) for blob in inputs}
    producers = get_output_producers(ssa)
    pending = [(str(blob), blob_versions[str(blob)]) for blob in outputs]
    visited = set()
    while pending:
        versioned_blob = pending.pop()
        if versioned_blob in given or versioned_blob not in producers:
            continue
        op_id = producers[versioned_blob]
        if op_id in visited:
            continue
        visited.add(op_id)
        op_inputs, _ = ssa[op_id]
        pending.extend(op_inputs)
    return sorted(visited)


def recurrent_network_op_remap(op, prefix, blob_remap):
    """
    Parameters
    ----------
    op : Caffe2 operator (RecurrentNetworkOp or RecurrentNetworkGradientOp).
    prefix : unused; only kept for legacy call compatibility.
    blob_remap : dict mapping old blob names to new ones.

    Rewrites, in place, the blob names stored in the arguments of the op so
    they agree with its cloned inputs and outputs:

    * every entry of a non-empty ``strings`` argument is remapped;
    * otherwise the ``s`` value of the ``timestep`` argument is remapped;
    * otherwise an argument whose name ends in ``step_net`` holds a sub-net,
      which is re-cloned through remap_proto().

    Remapped names are written back as UTF-8 encoded bytes; names that are
    not keys of ``blob_remap`` keep their value.
    """
    def remapped_bytes(blob_name):
        if isinstance(blob_name, bytes):
            blob_name = blob_name.decode('utf-8')
        return blob_remap.get(blob_name, blob_name).encode('utf-8')

    for argument in op.arg:
        if len(argument.strings):
            for pos, value in enumerate(list(argument.strings)):
                argument.strings[pos] = remapped_bytes(value)
        elif argument.name == 'timestep':
            argument.s = remapped_bytes(argument.s)
        elif argument.name.endswith('step_net'):
            remap_proto(argument, blob_remap)


def control_op_remap(op, prefix, blob_remap):
    """
    Remap blob names inside the sub-nets of a control-flow op, in place.

    For ``If`` / ``AsyncIf`` ops the ``then_net`` / ``else_net`` arguments
    are rewritten; for any other op type the ``loop_net`` / ``cond_net``
    arguments are. Each such sub-net is cloned with ``blob_remap`` (under its
    name plus ``'_remapped'``) and copied back into the argument. ``prefix``
    is unused.
    """
    if op.type in ("If", "AsyncIf"):
        net_arg_names = ['then_net', 'else_net']
    else:
        net_arg_names = ['loop_net', 'cond_net']

    for argument in op.arg:
        if argument.name not in net_arg_names:
            continue
        assert argument.n, \
            "Expected non empty net in " + op.type + "'s " + argument.name + " argument"
        subnet = Net(argument.n)
        remapped_subnet = subnet.Clone(
            name=(subnet._net.name or '') + '_remapped',
            blob_remap=blob_remap)
        argument.n.CopyFrom(remapped_subnet.Proto())


# Op type -> function(op, prefix, blob_remap) that rewrites the blob names
# embedded in that op's arguments.
DEFAULT_REMAP_FUNCS = dict(
    RecurrentNetwork=recurrent_network_op_remap,
    RecurrentNetworkGradient=recurrent_network_op_remap,
    If=control_op_remap,
    While=control_op_remap,
    AsyncIf=control_op_remap,
)


def remap_proto(argument, blob_remap):
    """
    Replace the sub-net stored in `argument.n` with a clone of it (named
    'cloned_sub_net') whose blob names are remapped through `blob_remap`.
    """
    cloned_sub_net = Net(argument.n).Clone('cloned_sub_net', blob_remap)
    argument.n.CopyFrom(cloned_sub_net.Proto())


def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None,
                       keep_schema=True):
    """
    Clone `net`, optionally binding its input schema to the record `inputs`.

    Every blob seen in the net is given a new name unless `blob_remap`
    (including the entries derived from `inputs`) already maps it. Blobs that
    some op reads before any op writes them keep their names; all other
    blobs get `prefix` prepended.

    Args:
        net:        the Net to clone
        name:       the name of the new net
        prefix:     the prefix to prepend to the names of local blobs
        blob_remap: (optional) dict with additional blob name remapping. When
                    given, it is updated in place and returned.
        inputs:     (optional) schema.Field record that will provide actual
                    input values for the cloned net. It must contain every
                    field of the net's input record (extra fields are allowed).
        keep_schema: by default (True), the original schema will be kept and
                     remapped accordingly. otherwise, the schema will be set as
                     inputs or left empty if inputs is not given.
    Returns:
        Tuple (cloned_net, blob_remap)
        cloned_net: the cloned Net
        blob_remap: a map from original blob names into remapped blob names
    """
    from caffe2.python import schema
    assert isinstance(net, Net)
    if blob_remap is None:
        blob_remap = {}

    if inputs is not None:
        assert isinstance(inputs, schema.Field)
        original = net.input_record()
        assert original is not None
        # TODO(azzolini): improve schema type checking
        missing = set(original.field_names()) - set(inputs.field_names())
        assert not missing, (
            "Schemas don't match, extra fields {diff} found in the net {name}. "
            "original: {original}; inputs: {inputs}"
            .format(
                diff=missing, name=net.Name(), original=original.field_names(),
                inputs=inputs.field_names()
            )
        )
        blob_of_field = dict(zip(original.field_names(),
                                 original.field_blobs()))
        for field_name, field_blob in zip(inputs.field_names(),
                                          inputs.field_blobs()):
            if field_name in blob_of_field:
                blob_remap[str(blob_of_field[field_name])] = str(field_blob)

    ssa, blob_versions = get_ssa(net.Proto())
    undef_blobs = get_undefined_blobs(ssa)
    for blob in blob_versions:
        if blob not in blob_remap:
            blob_remap[blob] = blob if blob in undef_blobs else prefix + blob

    cloned_net = net.Clone(name, blob_remap, keep_schema=keep_schema)
    if not keep_schema and inputs:
        cloned_net.set_input_record(inputs)
    return cloned_net, blob_remap


def _get_blob_ref(blob_name_or_ref):
    """
    Return `blob_name_or_ref` as a BlobReference.

    A BlobReference argument is returned as-is; anything else (e.g. a blob
    name string) is wrapped in a new BlobReference.
    """
    # Test the argument itself; the builtin `input` is never a BlobReference.
    return (
        blob_name_or_ref if isinstance(blob_name_or_ref, BlobReference)
        else BlobReference(blob_name_or_ref)
    )


def _recover_record_by_prefix(names, prefix=''):
    """
    Try to rebuild a schema record from the names in `names` that start with
    `prefix`, using the remainder of each such name as a column name.

    Returns None when no name starts with `prefix`.
    """
    from caffe2.python import schema
    offset = len(prefix)
    column_names = [name[offset:] for name in names if name.startswith(prefix)]
    if not column_names:
        return None
    col_blobs = [_get_blob_ref(prefix + column) for column in column_names]
    return schema.from_column_list(column_names, col_blobs=col_blobs)


class Net:
1449
    _net_names_used_counters: Dict[str, int] = {}
1450
    _net_names_used: Set[str] = set()
1451
    operator_registry_ = {}
1452

1453
    @staticmethod
    def current_prefix():
        """Return the name of the active NetBuilder, or '' when none is active."""
        from caffe2.python.net_builder import NetBuilder
        active_builder = NetBuilder.current(required=False)
        if active_builder:
            return active_builder.name
        return ''
1458

1459
    @staticmethod
1460
    def _reset_used_names() -> None:
1461
        Net._net_names_used_counters = {}
1462
        Net._net_names_used = set()
1463

1464
    @staticmethod
    def _get_next_net_name(basename):
        """
        Reserve and return a unique net name derived from `basename`.

        The active NetBuilder prefix (if any) is joined in front with '/'.
        Starting from the stored counter for that base name, the first
        candidate not yet in use is picked: the bare base name for counter 0,
        otherwise `<base>_<n>`. The counter and the set of used names are
        updated before returning.
        """
        parts = [Net.current_prefix(), basename]
        base = "/".join(part for part in parts if part)
        counter = Net._net_names_used_counters.get(base, 0)
        candidate = base if counter == 0 else f"{base}_{counter}"
        while candidate in Net._net_names_used:
            counter += 1
            candidate = f"{base}_{counter}"
        Net._net_names_used_counters[base] = counter + 1
        Net._net_names_used.add(candidate)
        return candidate
1475

1476
    def __init__(self, name_or_proto, inplace=False):
1477
        """
1478
        Create a Net.
1479
        Args:
1480
            name_or_proto:  If a NetDef is provided, clone it (or take ownership,
1481
                            depending on the value of `inplace`). Otherwise,
1482
                            create an empty net with the given name.
1483
            inplace: If a NetDef is provided, take ownership when `inplace` is True;
1484
                     otherwise, clone it.
1485
        """
1486
        self._input_record = None
1487
        self._output_record = None
1488
        # Register blobs so that it's guaranteed that different calls to
1489
        # NextBlob/NextScopedBlob always return blobs with different names
1490
        self._registered_blob_names = set()
1491
        self._recreate_lookup_tables = False
1492
        self._op_outputs = set()
1493
        self._external_input_map = set()
1494
        self._attr_dict = defaultdict(list)
1495
        if type(name_or_proto) is caffe2_pb2.NetDef:
1496
            proto = name_or_proto
1497
            # We are initializing a network by a NetDef. In this case, we will
1498
            # initialize our network with the given netdef.
1499
            if inplace:
1500
                self._net = proto
1501
            else:
1502
                self._net = caffe2_pb2.NetDef()
1503
                self._net.CopyFrom(proto)
1504

1505
            existing_outputs = [list(op.output) for op in self._net.op]
1506

1507
            self._external_input_map.update(list(self._net.external_input))
1508

1509
            # Set the next name index properly.
1510
            existing_names = set()
1511
            for op in self._net.op:
1512
                existing_names.update(list(op.input))
1513
            for output in existing_outputs:
1514
                existing_names.update(output)
1515

1516
            for outs in existing_outputs:
1517
                self._op_outputs.update(outs)
1518

1519
            prefix_len = len(self._net.name + '_blob_')
1520
            autogen_indices = []
1521
            for s in existing_names:
1522
                if s.startswith(self._net.name + '_blob_'):
1523
                    try:
1524
                        autogen_indices.append(int(s[prefix_len]))
1525
                    except ValueError:
1526
                        pass
1527
            if len(autogen_indices):
1528
                self._next_name_index = max(autogen_indices) + 1
1529
            else:
1530
                self._next_name_index = 0
1531
            name = self._net.name
1532
        else:
1533
            name = name_or_proto
1534
            self._net = caffe2_pb2.NetDef()
1535
            self._next_name_index = 0
1536

1537
        # make sure that this net name hasn't been used before
1538
        self._net.name = Net._get_next_net_name(name)
1539

1540
        # a map between prefix and ID for fast generation of blob names
1541
        self._next_blob_name_ids = {}
1542

1543

1544
    def AppendNet(self, net, device_option=None):
1545
        assert isinstance(net, Net)
1546
        for i in net.Proto().external_input:
1547
            if (
1548
                i not in self.Proto().external_input and
1549
                i not in self._op_outputs
1550
            ):
1551
                self.Proto().external_input.append(i)
1552

1553
        self.Proto().external_output.extend(
1554
            [
1555
                o for o in net.Proto().external_output
1556
                if o not in self.Proto().external_output
1557
            ]
1558
        )
1559
        ops = net.Proto().op
1560
        if device_option is not None:
1561
            ops = [copy.deepcopy(op) for op in ops]
1562
            for op in ops:
1563
                op.device_option.CopyFrom(device_option)
1564
            for op in ops:
1565
                if op.type == "RecurrentNetwork":
1566
                    for arg in op.arg:
1567
                        if arg.name.endswith('step_net'):
1568
                            for step_op in arg.n.op:
1569
                                step_op.device_option.CopyFrom(device_option)
1570

1571
        self._ExtendOps(ops)
1572
        return self
1573

1574
    def LogInfo(self, *msg_or_blobs):
1575
        for msg_or_blob in msg_or_blobs:
1576
            if not isinstance(msg_or_blob, BlobReference):
1577
                blob = self.GivenTensorStringFill(
1578
                    [], self.NextName('log'),
1579
                    shape=[], values=[msg_or_blob])
1580
            else:
1581
                blob = msg_or_blob
1582
            self.Print(blob, [])
1583

1584
    def add_attribute(self, name, obj):
1585
        """
1586
        Add `obj` to the list of attributes in this net under the given `name`.
1587
        Attributes are user-defined objects and have no pre-defined semantics.
1588
        """
1589
        self._attr_dict[name].append(obj)
1590

1591
    def get_attributes(self, name):
1592
        """
1593
        Returns the list of attributes in this net for a given `name`.
1594
        Attributes are user-defined objects added with `add_attribute'.
1595
        """
1596
        return self._attr_dict.get(name, [])
1597

1598
    def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
1599
        """
1600
        Adds a random seed to each op in the net.
1601
        If sequence_seed is set, the i-th op has rand_seed=`seed + i`
1602
        If seed_on_op_def is set, the op rand_seed=hash(str(op))
1603
        sequence_seed and seed_on_op_def cannot be both set to True.
1604
        """
1605
        assert not (sequence_seed and seed_on_op_def), (
1606
            'sequence_seed and seed_on_op_def cannot be both set to True.')
1607
        for i, op in enumerate(self.Proto().op):
1608
            if sequence_seed:
1609
                curr_seed = seed + i
1610
            elif seed_on_op_def:
1611
                curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
1612
            else:
1613
                curr_seed = seed
1614
            op.device_option.random_seed = curr_seed
1615

1616
    def Name(self):
1617
        return self._net.name
1618

1619
    def __str__(self):
1620
        return self.Name()
1621

1622
    def Const(self, array, blob_out=None, dtype=None):
1623
        if isinstance(array, bool):
1624
            return self.ConstantFill(
1625
                [],
1626
                blob_out or 1,
1627
                dtype=DataType.BOOL,
1628
                value=array)
1629

1630
        if dtype is None:
1631
            array = np.array(array)
1632
        else:
1633
            array = np.array(array, dtype=dtype)
1634

1635
        def do_set(operator):
1636
            return operator(
1637
                [],
1638
                blob_out or 1,
1639
                shape=array.shape,
1640
                values=array.flatten().tolist())
1641

1642
        if array.dtype == np.int32:
1643
            return do_set(self.GivenTensorIntFill)
1644
        elif array.dtype == np.int64:
1645
            return do_set(self.GivenTensorInt64Fill)
1646
        elif array.dtype == str:
1647
            return do_set(self.GivenTensorStringFill)
1648
        elif array.dtype == bool:
1649
            return do_set(self.GivenTensorBoolFill)
1650
        else:
1651
            return do_set(self.GivenTensorFill)
1652

1653
    def BlobIsDefined(self, blob):
1654
        """
1655
        Returns true if the given BlobReference is produced as output of
1656
        an operator in this net, or if it is provided as an external input.
1657
        """
1658
        if self._recreate_lookup_tables:
1659
            self._RecreateLookupTables()
1660
        name = str(blob)
1661
        return (name in self._op_outputs) or (name in self._external_input_map)
1662

1663
    def UsesBlob(self, blob):
1664
        """
1665
        Returns true iff the given BlobReference is used by any operator
1666
        or this net, or if it is one of the external inputs of the net.
1667
        """
1668
        blob_name = str(blob)
1669
        for op in self._net.op:
1670
            for input in op.input:
1671
                if input == blob_name:
1672
                    return True
1673
        return blob_name in self._external_input_map
1674

1675
    def UsedBlobNames(self):
1676
        """
1677
        Returns a set of blob names used in the net
1678
        """
1679
        blob_names = set()
1680
        for op in self._net.op:
1681
            blob_names |= set(op.input)
1682
            blob_names |= set(op.output)
1683
        if self._net.external_input:
1684
            blob_names |= set(self._net.external_input)
1685
        if self._net.external_output:
1686
            blob_names |= set(self._net.external_output)
1687
        return blob_names
1688

1689
    def GetBlobRef(self, blob_name):
1690
        """
1691
        Given the name of a blob produced by this net, return a BlobReference
1692
        to it. If the blob is not produced by any op in this net,
1693
        raises KeyError.
1694
        """
1695
        blob_name = str(blob_name)
1696
        if not self.BlobIsDefined(blob_name):
1697
            raise KeyError('Net does not define blob %s' % blob_name)
1698
        return BlobReference(blob_name, self)
1699

1700
    def Clone(
1701
        self,
1702
        name,
1703
        blob_remap=None,
1704
        op_id_mask=None,
1705
        remap_funcs=None,
1706
        keep_schema=True,
1707
        update_external_list=False,
1708
    ):
1709
        """
1710
        Clone this net.
1711
        Args:
1712
            name:        name of the cloned net
1713
            blob_remap:  optional map with list of blob names to replace
1714
            op_id_mask:  optional list of operator indices to include in
1715
                         the cloned net. If not provided, all ops are included.
1716
        """
1717
        orig_remap_funcs = {} if remap_funcs is None else remap_funcs
1718
        # by default we want to put RecurrentNetworkOp and
1719
        # RecurrentNetworkGradientOp into remap_funcs, as these two operators
1720
        # also take blobs and proto into the arguments.
1721
        remap_funcs = DEFAULT_REMAP_FUNCS.copy()
1722
        remap_funcs.update(orig_remap_funcs)
1723
        proto = self._net
1724
        new_proto = caffe2_pb2.NetDef()
1725
        new_proto.CopyFrom(proto)
1726
        new_proto.name = name
1727

1728
        if blob_remap is None:
1729
            blob_remap = {}
1730
        if op_id_mask is None:
1731
            op_id_mask = list(range(0, len(proto.op)))
1732

1733
        def get_remapped_str(blob):
1734
            blob_str = str(blob)
1735
            return str(blob_remap.get(blob_str, blob_str))
1736

1737
        def remap_list(proto_list):
1738
            new_list = [get_remapped_str(b) for b in proto_list]
1739
            del proto_list[:]
1740
            proto_list.extend(new_list)
1741

1742
        def remap_op(op):
1743
            new_op = caffe2_pb2.OperatorDef()
1744
            new_op.CopyFrom(op)
1745
            remap_list(new_op.input)
1746
            remap_list(new_op.output)
1747
            if new_op.type in remap_funcs:
1748
                remap_funcs[new_op.type](
1749
                    new_op,
1750
                    (name + '/') if name else '',
1751
                    blob_remap,
1752
                )
1753
            return new_op
1754

1755
        del new_proto.op[:]
1756
        new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
1757
        remap_list(new_proto.external_input)
1758
        remap_list(new_proto.external_output)
1759
        new_net = Net(new_proto)
1760

1761
        if keep_schema:
1762
            from caffe2.python import schema
1763
            if self._input_record:
1764
                new_net._input_record = schema.from_blob_list(
1765
                    self._input_record,
1766
                    [
1767
                        BlobReference(get_remapped_str(blob), net=new_net)
1768
                        for blob in self._input_record.field_blobs()
1769
                    ],
1770
                )
1771
            if self._output_record:
1772
                new_net._output_record = schema.from_blob_list(
1773
                    self._output_record,
1774
                    [
1775
                        BlobReference(get_remapped_str(blob), net=new_net)
1776
                        for blob in self._output_record.field_blobs()
1777
                    ],
1778
                )
1779

1780
        new_net._attr_dict.update(self._attr_dict)
1781
        if update_external_list:
1782
            # external input list
1783
            existing_outputs = set()
1784
            used_outputs = set()
1785
            del new_net.Proto().external_input[:]
1786
            del new_net.Proto().external_output[:]
1787
            for op in new_net.Proto().op:
1788
                for ib in op.input:
1789
                    if ib not in existing_outputs:
1790
                        new_net.Proto().external_input.extend([ib])
1791
                    else:
1792
                        used_outputs.add(ib)
1793
                for ob in op.output:
1794
                    existing_outputs.add(ob)
1795
            # external outputs
1796
            for ob in existing_outputs:
1797
                if ob not in used_outputs:
1798
                    new_net.Proto().external_output.extend([ob])
1799
        return new_net
1800

1801
    def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
        """
        Clone only the ops of this net that are needed to compute `outputs`
        from `inputs`, and return references to the cloned outputs.

        Blobs the cloned ops read but never produce keep their names. All
        other blobs, apart from the mapped inputs, are renamed with the
        prefix `name + '/'` (when `name` is non-empty) to avoid clashes.

        Args:
            name:    the name of the cloned net
            inputs:  a dict (or a list of (original, new) pairs) mapping
                     blobs of this net to external inputs of the clone; a
                     plain list keeps the input names unchanged.
            outputs: blobs of this net that the clone has to produce.
            remap_funcs: forwarded to Clone().

        Returns:
            Tuple (new_net, new_outputs): the cloned Net and the list of
            BlobReferences to the outputs it produces.
        """
        if isinstance(inputs, (dict, OrderedDict)):
            input_map = inputs
        elif isinstance(inputs, list) and all(
                isinstance(i, tuple) and len(i) == 2 for i in inputs):
            input_map = OrderedDict(inputs)
        else:
            input_map = OrderedDict(zip(inputs, inputs))

        for wanted in outputs:
            assert self.BlobIsDefined(wanted), "{} is not defined".format(wanted)
        input_names = {str(src): str(dst) for src, dst in input_map.items()}
        output_names = [str(o) for o in outputs]
        net_def = self._net
        ssa, blob_versions = get_ssa(net_def, {str(src): 0 for src in input_map})
        used_op_ids = get_op_ids_in_path(ssa, blob_versions, input_map, outputs)
        disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], input_map)
        assert not (set(used_op_ids) & set(disallowed_op_ids)), (
            'Cannot partially clone net: some of the ops required would ' +
            'generate the given input.')

        sub_ssa = [entry for idx, entry in enumerate(ssa) if idx in used_op_ids]
        undef_blobs = get_undefined_blobs(sub_ssa) - set(input_names.keys())
        prefix = (name + '/') if name else ''

        def remap(blob_name):
            if blob_name in input_names:
                return input_names[blob_name]
            if blob_name in undef_blobs:
                return blob_name
            return prefix + blob_name

        blob_mapping = {blob: remap(blob) for blob in blob_versions}
        new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
        new_in = [blob_mapping[src] for src in input_names] + list(undef_blobs)
        new_out = [blob_mapping[o] for o in output_names]
        cloned_def = new_net.Proto()
        del cloned_def.external_input[:]
        cloned_def.external_input.extend(new_in)
        new_net._external_input_map = set(new_in)
        del cloned_def.external_output[:]
        cloned_def.external_output.extend(new_out)
        return new_net, [new_net.GetBlobRef(o) for o in new_out]
1865

1866
    def Proto(self):
1867
        self._InvalidateLookupTables()
1868
        return self._net
1869

1870
    def insert_op_at_idx(self, op, op_idx):
1871
        r""" inserting operator at index. Will update external blob list.
1872
        """
1873
        assert op_idx >= 0
1874
        temp_ops = self.Proto().op[op_idx:]
1875
        del self.Proto().op[op_idx:]
1876
        self.Proto().op.extend([op])
1877
        self.Proto().op.extend(temp_ops)
1878
        self.external_outputs.extend(op.output)
1879
        self.external_inputs.extend(op.input)
1880

1881
    def reroute_tensor(self, tensor, new_producer, can_modify=None):
        r"""Insert `new_producer` into the net and redirect `tensor` to its
        first output.

        `new_producer` is inserted right after the last op that produces any
        of its inputs. If `tensor` is an external output, it is replaced there
        by the new tensor. When `can_modify` is given, every op of the net
        found in it gets `tensor` remapped to the new tensor in its inputs.

        Inputs:
            tensor: str or BlobReference, the tensor to reroute
            new_producer: an op that takes in `tensor`; its first output is
                the new tensor
            can_modify: optional list/set of ops consuming `tensor` that may
                be modified

        Returns:
            reroute_cnt: how many consumer ops have been changed

        Note: assumes no in-place blob in the net.
        """
        def producer_position(blob):
            # -1: the blob is an external input (no producing op).
            if blob in self.external_inputs:
                return -1
            assert blob in new_producer.input, \
                "new producer {} is not taking in {}".format(
                    new_producer.type, blob)
            # assuming that the net has no inplace blob; -2 if no op
            # produces the blob
            return next(
                (idx for idx, op in enumerate(self.Proto().op)
                 if blob in op.output),
                -2,
            )

        # The insertion point depends on all inputs of new_producer, not just
        # on `tensor`.
        insert_at = max(producer_position(b) for b in new_producer.input) + 1
        self.insert_op_at_idx(new_producer, insert_at)
        new_tensor = new_producer.output[0]

        if tensor in self.external_outputs:
            replaced = [
                new_tensor if b == tensor else b for b in self.external_outputs
            ]
            del self.Proto().external_output[:]
            self.Proto().external_output.extend(replaced)

        if not can_modify:
            return 0
        reroute_cnt = 0
        for op in self.Proto().op:
            if op in can_modify:  # this is not necessarily true
                remap_input(op, {tensor: new_tensor})
                reroute_cnt += 1
        return reroute_cnt
1934

1935
    def PopulateProtoWithFileName(self):
        """
        Set op names from `workspace.operator_tracebacks` for this net.

        For each op index that has an entry, the op's name becomes the
        elements of that entry's first item joined with ':'. Does nothing
        when no entry exists for this net's name.
        """
        tracebacks = workspace.operator_tracebacks.get(self.Name(), None)
        if tracebacks is None:
            return
        for idx, op in enumerate(self.Proto().op):
            if idx not in tracebacks:
                continue
            op.name = ':'.join(str(part) for part in tracebacks[idx][0])
1941

1942
    def NextScopedBlob(self, prefix='unnamed'):
        """Like NextBlob, but `prefix` is first qualified with the current
        name scope (via ScopedName). The result is neither defined in nor
        registered with this net yet, so different calls are guaranteed to
        return blobs with different names.
        """
        return self.NextBlob(ScopedName(prefix))
1950

1951
    def NextBlob(self, prefix='unnamed'):
        """Return a BlobReference whose name is not yet defined in or
        registered with this net, and register it.

        The candidate is `BlobReference(prefix)`; if that is taken,
        `<prefix>_auto_0`, `<prefix>_auto_1`, ... are tried in turn. Different
        calls are guaranteed to return blobs with different names."""
        base = BlobReference(prefix)
        candidate = base
        suffix = 0
        while True:
            taken = (str(candidate) in self._registered_blob_names
                     or self.BlobIsDefined(candidate))
            if not taken:
                break
            candidate = base + '_auto_' + str(suffix)
            suffix += 1
        self._registered_blob_names.add(str(candidate))
        return candidate
1966

1967
    def NextName(self, prefix=None, output_id=None):
1968
        """Returns the next name to be used, if you do not want to explicitly
1969
        name your blob. [Deprecated, use NextBlob, NextScopedBlob instead]"""
1970
        if prefix:
1971
            output_name_base = self._net.name + '/' + prefix
1972
            output_name = output_name_base
1973
            if output_id is not None:
1974
                output_name += ':' + str(output_id)
1975
            key = output_name
1976
            index = self._next_blob_name_ids.get(key, 2)
1977
            while self.BlobIsDefined(str(ScopedBlobReference(output_name))):
1978
                output_name = output_name_base + '_' + str(index)
1979
                if output_id is not None:
1980
                    output_name += ':' + str(output_id)
1981
                index += 1
1982
                self._next_blob_name_ids[key] = index
1983
        else:
1984
            output_name = self._net.name + '_blob_' + str(self._next_name_index)
1985
            self._next_name_index += 1
1986
        return str(output_name)
1987

1988
    def _ExtendOps(self, new_ops):
1989
        self._net.op.extend(new_ops)
1990
        for op in new_ops:
1991
            self._op_outputs.update([str(o) for o in op.output])
1992

1993
    def _CheckLookupTables(self):
1994
        '''
1995
        Called from unit tests to validate the internal lookup tables
1996
        match the protobuf contents.
1997
        '''
1998
        test_op_outputs = set()
1999
        for op in self._net.op:
2000
            for o in op.output:
2001
                test_op_outputs.add(o)
2002

2003
        test_external_inp = set()
2004
        for inp in self._net.external_input:
2005
            test_external_inp.add(inp)
2006

2007
        assert test_op_outputs.difference(self._op_outputs) == set()
2008
        assert test_external_inp.difference(self._external_input_map) == set()
2009

2010
    def _InvalidateLookupTables(self):
2011
        self._recreate_lookup_tables = True
2012

2013
    def _RecreateLookupTables(self):
2014
        self._op_outputs = {o for op in self._net.op for o in op.output}
2015
        self._external_input_map = {inp for inp in self._net.external_input}
2016
        self._recreate_lookup_tables = False
2017

2018
    def AddGradientOperators(self, ys, skip=0):
        """Append the gradient operators for this net's ops.

        Inputs:
          ys: a list or a dictionary of the blobs whose derivatives we want.
              For a list, their gradients are auto-filled with all-one
              values. For a dictionary, a non-None entry is used as that
              blob's gradient, and None entries are auto-filled with 1.
          skip: ignore the first `skip` operators. Useful because many nets
              start with data-generation ops that need no gradients.

        Outputs:
          a map from each blob name of the forward net to its gradient blob,
          or to a GradientSlice for sparse gradients.

        When a blob feeds several operators, the gradient accumulation (Sum)
        currently supports float only, so such branches are float-only.
        """
        forward_ops = self._net.op[skip:]
        grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
            forward_ops, ys)
        # In immediate mode the gradient ops come straight from C++ and skip
        # CreateOperator(), so they have to be run explicitly here.
        if workspace.IsImmediate():
            for grad_op in grad_ops:
                workspace.RunOperatorImmediate(grad_op)
        self._ExtendOps(grad_ops)
        return input_to_grad
2051

2052
    def AddArgument(self, arg_name, arg_value):
        """Append a net-level argument `arg_name` = `arg_value` to the NetDef."""
        argument = utils.MakeArgument(arg_name, arg_value)
        self._net.arg.extend([argument])
2054

2055
    def AddExternalInput(self, *inputs):
2056
        assert len(inputs) > 0
2057
        refs = []
2058
        input_name_set = set()
2059
        for input in inputs:
2060
            input_name = str(input)
2061
            assert (
2062
                input_name not in self._external_input_map
2063
                and input_name not in input_name_set
2064
            ), ("Net already contains an input named %s" % input_name)
2065
            input_name_set.add(input_name)
2066
        for input in inputs:
2067
            input_name = str(input)
2068
            self._net.external_input.extend([input_name])
2069
            self._external_input_map.update([input_name])
2070
            refs.append(_get_blob_ref(input_name))
2071

2072
        return refs[0] if len(refs) == 1 else refs
2073

2074
    def AddExternalOutput(self, *outputs):
2075
        for output in outputs:
2076
            assert isinstance(output, BlobReference)
2077
            assert self.BlobIsDefined(output), "{} is not defined".format(output)
2078
        for output in outputs:
2079
            self.Proto().external_output.extend([str(output)])
2080

2081
    def AddScopedExternalInputs(self, *inputs):
2082
        res = self.AddExternalInput(
2083
            * [ScopedBlobReference(b) for b in inputs]
2084
        )
2085
        if not isinstance(res, list):
2086
            res = [res]
2087
        return res
2088

2089
    def AddScopedExternalOutputs(self, *outputs):
2090
        return self.AddExternalOutput(
2091
            * [ScopedBlobReference(b) for b in outputs]
2092
        )
2093

2094
    def AddObserver(self, observer_type):
        """Attach an observer of `observer_type` to this net (looked up by
        name in the C extension) and return a reference to the observer."""
        net_name = self._net.name
        return C.add_observer_to_net(net_name, observer_type)
2097

2098
    def RemoveObserver(self, observer):
        """Detach `observer` from this net via the C extension."""
        net_name = self._net.name
        C.remove_observer_from_net(net_name, observer)
2100

2101
    def NumObservers(self):
        """Return how many observers the C extension reports for this net."""
        net_name = self._net.name
        return C.num_observers_on_net(net_name)
2103

2104
    @property
2105
    def external_inputs(self):
2106
        return [_get_blob_ref(x) for x in self._net.external_input]
2107

2108
    @property
2109
    def external_outputs(self):
2110
        return [_get_blob_ref(x) for x in self._net.external_output]
2111

2112
    def set_input_record(self, input_record):
        """
        Attach `input_record` as this net's input schema and return it.

        A record without blobs is first instantiated via schema.NewRecord
        inside this net's name scope. Every field blob that is not yet an
        external input is registered as one. An existing input record can
        only be replaced by a record with exactly the same set of blobs.
        """
        from caffe2.python import schema
        can_set = self._input_record is None or (
            input_record.has_blobs()
            and set(input_record.field_blobs())
            == set(self._input_record.field_blobs())
        )
        assert can_set, 'Input schema cannot be reset'
        if input_record.has_blobs():
            self._input_record = input_record
        else:
            with NameScope(self.Name()):
                self._input_record = schema.NewRecord(self, input_record)

        for blob in self._input_record.field_blobs():
            if not self.is_external_input(blob):
                self.AddExternalInput(blob)
        return self._input_record
2128

2129
    def recover_input_record_by_prefix(self, prefix):
        """
        Rebuild the input record from the external inputs named
        `<prefix><column>`, using the column parts as schema column names.
        Does nothing when no record can be recovered.
        """
        record = _recover_record_by_prefix(self._net.external_input, prefix)
        if not record:
            return
        self.set_input_record(record)
2137

2138
    def set_output_record(self, record):
2139
        assert self._output_record is None or (record.has_blobs() and
2140
            set(record.field_blobs()) ==
2141
            set(self._output_record.field_blobs())), (
2142
            'Output schema cannot be reset')
2143
        for blob in record.field_blobs():
2144
            assert self.BlobIsDefined(blob), "{} is not defined in net {}".format(
2145
                blob,
2146
                self.Proto()
2147
            )
2148
        for blob in record.field_blobs():
2149
            if blob not in self.external_outputs:
2150
                self.AddExternalOutput(blob)
2151
        self._output_record = record
2152

2153
    def recover_output_record_by_prefix(self, prefix):
        """
        Rebuild the output record from the external outputs named
        `<prefix><column>`, using the column parts as schema column names.
        Does nothing when no record can be recovered.
        """
        record = _recover_record_by_prefix(self._net.external_output, prefix)
        if not record:
            return
        self.set_output_record(record)
2161

2162
    def AppendOutputRecordField(self, field_name, record):
        """
        Extend the existing output record with a new field `field_name`
        holding `record`.

        All blobs of `record` must be defined in this net and are appended
        to the external outputs. Requires an output record to be set.
        """
        from caffe2.python import schema
        assert self._output_record is not None, (
            'Tried to append to missing output record'
        )
        for blob in record.field_blobs():
            assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
        for blob in record.field_blobs():
            self.AddExternalOutput(blob)
        extra_field = schema.Struct((field_name, record))
        self._output_record = self._output_record + extra_field
2174

2175
    def input_record(self):
        """Return the schema record registered as this net's input (possibly None — TODO confirm default)."""
        current = self._input_record
        return current

    def output_record(self):
        """Return the schema record registered as this net's output, or None if none was set."""
        current = self._output_record
        return current

    def AddExternalInputs(self, *inputs):
        """Plural alias of `AddExternalInput`: forward all blobs and return its result."""
        result = self.AddExternalInput(*inputs)
        return result

    def AddExternalOutputs(self, *outputs):
        """
        Plural alias of `AddExternalOutput`: forward all blobs and return its
        result, mirroring `AddExternalInputs` (previously the result was
        silently dropped).
        """
        return self.AddExternalOutput(*outputs)

    def DeduplicateGradientSlices(self, g, aggregator='sum'):
        """
        Collapse repeated indices of the sparse gradient `g` (a GradientSlice).

        Adds a `Unique` op (SparseHash engine) over `g.indices`. It then
        reduces `g.values` through the resulting remapping with
        `UnsortedSegmentSum` (aggregator 'sum') or `UnsortedSegmentMean`
        (aggregator 'mean'). `aggregator` is matched case-insensitively.

        Returns a GradientSlice of the unique indices and the reduced values.
        Raises ValueError for any other aggregator; by then the Unique op has
        already been added to the net.
        """
        assert isinstance(g, GradientSlice)
        unique, remapping = self.Unique([g.indices], 2, engine='SparseHash')
        reducer_by_name = {
            'sum': 'UnsortedSegmentSum',
            'mean': 'UnsortedSegmentMean',
        }
        reducer = reducer_by_name.get(aggregator.lower())
        if reducer is None:
            raise ValueError('{} is not supported'.format(aggregator))
        new_g = getattr(self, reducer)([g.values, remapping], 1)
        return GradientSlice(indices=unique, values=new_g)

    @staticmethod
    def _RunAllOnGPU(net, gpu_id=0, use_cudnn=False):
        """
        Replace the device option of the NetDef `net` with one targeting GPU
        `gpu_id`.

        When `use_cudnn` is set, the engine of every op in `net` is set to
        "CUDNN". The same treatment is applied recursively to the `step_net`
        argument of each RecurrentNetwork op.
        """
        gpu_option = caffe2_pb2.DeviceOption()
        gpu_option.device_type = workspace.GpuDeviceType
        gpu_option.device_id = gpu_id
        net.device_option.CopyFrom(gpu_option)

        if use_cudnn:
            for op in net.op:
                op.engine = "CUDNN"

        # RecurrentNetwork ops carry their step net as an argument; move it too.
        step_nets = (
            arg.n
            for op in net.op
            if op.type == "RecurrentNetwork"
            for arg in op.arg
            if arg.name == "step_net"
        )
        for step_net in step_nets:
            Net._RunAllOnGPU(step_net, gpu_id, use_cudnn)

    def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
        """
        A convenient function to run everything on the GPU: applies
        `_RunAllOnGPU` to this net's underlying NetDef.
        """
        target = self._net
        self._RunAllOnGPU(target, gpu_id, use_cudnn)

    def RunAllOnMKL(self):
        """
        A convenient function to run everything using MKLDNN: replaces this
        net's device option with one whose device type is MKLDNN.
        """
        mkldnn_option = caffe2_pb2.DeviceOption()
        mkldnn_option.device_type = caffe2_pb2.MKLDNN
        self._net.device_option.CopyFrom(mkldnn_option)

    def RunAllOnIDEEP(self):
        """
        A convenient function to run everything using IDEEP: replaces this
        net's device option with one whose device type is IDEEP.
        """
        ideep_option = caffe2_pb2.DeviceOption()
        ideep_option.device_type = caffe2_pb2.IDEEP
        self._net.device_option.CopyFrom(ideep_option)

    def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs):
        """
        Create an operator of type `op_type`, append it to this net and
        return references to its outputs.

        Inputs not yet defined in this net are registered as external inputs;
        such inputs must not belong to this net. `outputs` may be:
          - None: one auto-named output is assumed;
          - an int: that many auto-named outputs;
          - explicit blob names/references.
        Extra keyword arguments are passed to CreateOperator.

        Returns None when the op has no outputs, a single BlobReference for
        one output, or a tuple of BlobReferences otherwise.
        """
        inputs = _RectifyInputOutput(inputs)
        for blob in inputs:
            if self.BlobIsDefined(blob):
                continue
            assert blob.Net() != self
            self.AddExternalInput(blob)

        if outputs is None:
            # No outputs given: assume the op produces exactly one output.
            outputs = self.NextName(prefix=op_type)
        elif type(outputs) is int:
            # An output count: generate that many names automatically.
            outputs = [self.NextName(prefix=op_type, output_id=idx)
                       for idx in range(outputs)]
        outputs = _RectifyInputOutput(outputs, net=self)

        op = CreateOperator(op_type, inputs, outputs, **kwargs)
        self._ExtendOps([op])

        # Store the creation stack trace keyed by net name and op index.
        op_index = len(self._net.op) - 1
        workspace.operator_tracebacks[self.Name()][op_index] = _extract_stacktrace()

        produced = op.output
        count = len(produced)
        if count == 0:
            return None
        if count == 1:
            return BlobReference(produced[0], self)
        return tuple(BlobReference(name, self) for name in produced)

    def __getattr__(self, op_type):
        # Invoked only for names not found through normal lookup: treat the
        # name as an operator type and return a builder that appends such an
        # op to this net. Dunder names are rejected outright, presumably so
        # that protocol probes never turn into operator lookups.
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        registered = IsOperator(op_type) or IsOperatorWithEngine(op_type, "CUDNN")
        if not registered:
            suggestions = ",".join(workspace.C.nearby_opnames(op_type))
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' + suggestions + ']'
            )

        def _add_op(*args, **kwargs):
            return self._CreateAndAddToSelf(op_type, *args, **kwargs)
        return _add_op

    def __dir__(self):
        """
        List the regular attributes plus every registered operator name that
        is not an engine-specific variant (i.e. does not contain '_ENGINE_').
        Operators reachable through __getattr__ therefore appear in dir().
        The lazy import is triggered before the registry is read.
        """
        TriggerLazyImport()
        names = set(dir(type(self)))
        names.update(self.__dict__.keys())
        names.update(op for op in _REGISTERED_OPERATORS if '_ENGINE_' not in op)
        return sorted(names)

    def Python(
        self,
        f,
        grad_f=None,
        python_func_type=None,
        pass_workspace=False,
        grad_output_indices=None,
        grad_input_indices=None
    ):
        """
        Registers and returns a python operator.

        `f` and `grad_f` can be one of the following:
            - a function with signature (inputs, outputs), where inputs and
              outputs are a list of CPUTensor objects. This function will be
              called from C++ every time the operator is executed.
            - a tuple (func, args, kwargs), where `func` is a callable, args
              is an argument list, and kwargs is a dict. The call:
                  f = func(*args, **kwargs)
              will be performed locally at node initialization time, on all of
              the nodes of the job, returning `f`, a callable that will be used
              as the python operator function to be called during Net execution.
              This is to be used when using python operator in a distributed
              context, and allows to create and keep local python state across
              calls to the operator.
              The tuple is serialized with `pickle` into the op's arguments,
              so it must be picklable.

        `python_func_type` is a type of an object that is constructed as
        python_func_type(f) and provides an implementation of the forward and
        backward functions. It is useful in cases where users need a stateful
        PythonOp (e.g. using autograd to compute grad_f).

        If `pass_workspace` is True, the signature is changed to
        (inputs, outputs, workspace) where `workspace` is the workspace the op
        is going to run on. This is potentially dangerous (as the op can
        manipulate the workspace directly); use at your own risk.

        If a gradient function is specified (`grad_f`), by default its inputs
        will be: (1) all inputs to `f`, (2) followed by all outputs of `f`, (3)
        and then all gradient outputs of `f`. The outputs of `grad_f` will be
        (by default) all gradient inputs to `f`. If a subset of the gradient
        outputs or gradient inputs is desired instead, then the subsets can be
        specified by providing `grad_output_indices` and/or `grad_input_indices`
        which identify the indices of `f`'s inputs and outputs which have
        gradients.

        Returns a callable that, when invoked with the usual operator
        arguments (inputs, outputs, **kwargs), appends a 'Python' op to this
        net and returns its output reference(s).

        Raises AssertionError if the Python operator is not registered, or if
        only one of `f` / `grad_f` is given as a builder tuple.
        """
        assert IsOperator('Python')

        def make_builder(t):
            # Non-tuples are plain callables: no builder ('' means "none").
            if not isinstance(t, tuple):
                return ''
            assert len(t) == 3, 'Expected builder tuple (func, args, kwargs)'
            func, args, kwargs = t
            normalized = (func, tuple(args), dict(kwargs))
            return pickle.dumps(normalized)

        f_builder = make_builder(f)
        grad_f_builder = make_builder(grad_f)

        assert (not grad_f) or ((not f_builder) == (not grad_f_builder)), (
            'A tuple has to be passed to both f and grad_f or neither.')

        core_kwargs = {}
        if f_builder:
            core_kwargs['pickled_builder'] = f_builder
            core_kwargs['pickled_grad_builder'] = grad_f_builder
            core_kwargs['pass_workspace'] = pass_workspace
        else:
            core_kwargs['token'] = _RegisterPythonImpl(
                f, grad_f, python_func_type, pass_workspace=pass_workspace)

        grad_output_indices = grad_output_indices or []
        grad_input_indices = grad_input_indices or []

        def _add_python_op(*args, **kwargs):
            # core_kwargs take precedence over caller-supplied kwargs.
            merged_kwargs = {**kwargs, **core_kwargs}
            return self._CreateAndAddToSelf(
                'Python',
                *args,
                grad_output_indices=grad_output_indices,
                grad_input_indices=grad_input_indices,
                **merged_kwargs
            )
        return _add_python_op

    def is_external_input(self, blob):
        """
        Return True if `blob` (compared by its str() form) is an external
        input of this net. The lookup tables are rebuilt first when they are
        flagged for recreation.
        """
        if self._recreate_lookup_tables:
            self._RecreateLookupTables()
        return str(blob) in self._external_input_map

    def extend_ops(self, new_ops):
        """Public wrapper around `_ExtendOps`: append `new_ops` and return its result."""
        outcome = self._ExtendOps(new_ops)
        return outcome


def remap_input(op, blob_name_remapping):
    """
    Rename the inputs of `op` in place.

    Every input name that is a key of `blob_name_remapping` is replaced by
    its mapped value; other inputs keep their names. Order is preserved and
    the same `op.input` container is reused.
    """
    lookup = blob_name_remapping.get
    renamed = list(map(lambda name: lookup(name, name), op.input))
    del op.input[:]
    op.input.extend(renamed)


def copy_func_between_devices(src, dst):
    """
    Return a function `fun(net, *args, **kw)` that adds to `net` the op
    copying a blob from device `src` to device `dst`. Returns None when no
    copy is needed.

      - CPU -> CPU, or GPU -> GPU with the same device_id: None.
      - GPU -> another GPU: `Copy`, built inside DeviceScope(dst).
      - GPU -> CPU: `CopyGPUToCPU`, built inside DeviceScope(src).
      - CPU -> GPU: `CopyCPUToGPU`, built inside DeviceScope(dst).

    Raises ValueError for any other combination of device types.
    """
    CPU = caffe2_pb2.CPU
    src_on_gpu = IsGPUDeviceType(src.device_type)
    dst_on_gpu = IsGPUDeviceType(dst.device_type)
    src_on_cpu = src.device_type == CPU
    dst_on_cpu = dst.device_type == CPU

    def _scoped(scope_device, op_name):
        # Builder that adds `op_name` to the net inside scope_device's scope.
        def fun(net, *args, **kw):
            with DeviceScope(scope_device):
                return getattr(net, op_name)(*args, **kw)
        return fun

    if src_on_cpu and dst_on_cpu:
        return None
    if src_on_gpu and dst_on_gpu:
        if src.device_id == dst.device_id:
            return None
        return _scoped(dst, 'Copy')
    if src_on_gpu and dst_on_cpu:
        return _scoped(src, 'CopyGPUToCPU')
    if src_on_cpu and dst_on_gpu:
        return _scoped(dst, 'CopyCPUToGPU')
    raise ValueError('Non-supported devices: %s and %s' % (src, dst))


def device_equal(src, dst):
    '''
    Return True when `src` and `dst` have the same device_type and device_id.

    This function is used instead of the == operator because comparing an
    empty device option with one explicitly set to {device_type: 0,
    device_id: 0} via == reports them as different in some cases.
    '''
    if src.device_type != dst.device_type:
        return False
    return src.device_id == dst.device_id


def update_placeholder_op_output(op, blob_to_device):
    '''
    Rename, in place, the outputs of a placeholder op (e.g. Recv) that
    `blob_to_device` places on a non-CPU device by appending '_cpu'.
    Placeholder ops always run on CPU, so their outputs must reside on CPU.
    Outputs that are unknown to `blob_to_device`, or already on CPU, keep
    their names.
    '''
    def _cpu_name(blob):
        if blob not in blob_to_device:
            return blob
        if blob_to_device[blob].device_type == caffe2_pb2.CPU:
            return blob
        return blob + '_cpu'

    renamed = [_cpu_name(blob) for blob in op.output]
    del op.output[:]
    op.output.extend(renamed)


class RemapEntry:
    """
    Hashable (blob, device) key, used in the `blob_remap` dictionaries that
    map a blob and a target device to the name of the blob's copy on that
    device.
    """

    def __init__(self, blob, device):
        self.blob = blob
        self.device = device

    def __eq__(self, other):
        # Comparing with an unrelated type used to raise AttributeError;
        # defer to Python's default handling (which yields False) instead.
        if not isinstance(other, RemapEntry):
            return NotImplemented
        return self.blob == other.blob and self.device == other.device

    def __hash__(self):
        # Hash the pair as a tuple: the previous string concatenation made
        # distinct pairs such as ('ab', 'c') and ('a', 'bc') collide.
        # str(device) is used presumably because device options (protobuf
        # messages) are not hashable themselves.
        return hash((self.blob, str(self.device)))

    def __repr__(self):
        return 'RemapEntry(blob={!r}, device={!r})'.format(self.blob, self.device)


def InjectCrossDeviceCopies(net, blob_to_device=None, blob_remap=None,
                            placeHolderOps=None):
    '''
    Inject Copy ops between devices within a net. Users can provide a net
    whose operators use different device_options. This method
    automatically creates a new net with the needed Copy ops inserted; the
    input net's ops are not modified.

    Inputs:
      net: the Net to process.
      blob_to_device: If not None, it is a map of blobs and their device
                      locations. It is updated in place and also returned.
      blob_remap: If not None, it is a map from a RemapEntry(blob, device) to
                  the name of the blob in the given device. Blobs found in this
                  map are assumed to be cached and don't need to be copied.
                  Copies created here are added to it in place.
      placeHolderOps: optional collection of op types treated as placeholder
                  (fake) ops. Their device is taken from their blobs' devices,
                  and their outputs are renamed with update_placeholder_op_output.
    Outputs:
      new_net: A new net with cross-device copy ops inserted with the
               correct device option.
      blob_to_device: the updated blob -> DeviceOption map, including the
               created copies and every op output of new_net.
    Assumptions:
      1. every external inputs of this net is already in blob_to_device!
      2. if not, this function will use net device option
      3. InferOpBlobDevices might fail to get the correct inference for ops like
         EnsureCPUOutput that could take in input from multiple places.
    '''
    def _gen_new_name(blob, device_option):
        # Name of the copy of `blob` that lives on `device_option`.
        if device_option.device_type == caffe2_pb2.CPU:
            suffix = '_cpu'
        elif IsGPUDeviceType(device_option.device_type):
            suffix = '_gpu_' + str(device_option.device_id)
        else:
            raise RuntimeError(
                "Unknown device type: {}".
                format(device_option.device_type)
            )
        return blob + suffix

    new_net = net.Clone(net._net.name + '_cross_device', keep_schema=True)
    del new_net._net.op[:]
    if blob_to_device is None:
        blob_to_device = {}
    # remapping of input blobs for each op.
    if blob_remap is None:
        blob_remap = {}
    temp_remap = {}
    net_option = net._net.device_option or caffe2_pb2.DeviceOption()

    # if external_inputs have device remappings generated by previous nets,
    # then add those remappings as external inputs as well.
    all_remaps = defaultdict(list)
    for entry, mapped_blob in blob_remap.items():
        all_remaps[entry.blob].append(mapped_blob)
    mapped_external_inputs = []
    for blob in new_net._net.external_input:
        mapped_external_inputs.extend(all_remaps.get(blob) or [])
    new_net._net.external_input.extend(mapped_external_inputs)

    for op in net._net.op:
        temp_remap.clear()
        is_placeholder = placeHolderOps is not None and op.type in placeHolderOps
        # Get where inputs and outputs should be. If it is a Placeholder
        # (i.e. fake) op, then set op's device as blob's devices.
        if is_placeholder:
            input_dev, output_dev = InferOpDeviceAsBlobDevices(op)
        else:
            input_dev, output_dev = InferOpBlobDevices(op)

        for dev, blob in zip(input_dev, op.input):
            assert net.BlobIsDefined(blob), \
                "input {} should be defined in the net.".format(blob)
            if blob not in blob_to_device:
                if net.is_external_input(blob):
                    blob_to_device[blob] = net_option
                else:
                    raise AttributeError(
                        "No device information found for blob {}.".
                        format(blob)
                    )

            if device_equal(blob_to_device[blob], dev):
                continue

            entry = RemapEntry(blob, dev)
            # Reuse an already moved input. Compare with device_equal rather
            # than == for the same reason as the check above.
            if entry in blob_remap and device_equal(
                    blob_to_device[blob_remap[entry]], dev):
                temp_remap[blob] = blob_remap[entry]
                continue

            # need to make input on correct device.
            copy_func = copy_func_between_devices(blob_to_device[blob], dev)
            new_name = _gen_new_name(blob, dev)
            copy_func(new_net, blob, new_name)
            blob_remap[entry] = new_name
            temp_remap[blob] = new_name
            blob_to_device[new_name] = dev

        # Work on a copy so the caller's net is never modified; in particular,
        # placeholder outputs are renamed on the copy only (renaming the
        # original op made a second pass append '_cpu' again).
        new_op = caffe2_pb2.OperatorDef()
        new_op.CopyFrom(op)
        if is_placeholder:
            update_placeholder_op_output(new_op, blob_to_device)

        # Enforcing no reuse blob between operators. In-place blob usage in an
        # op is allowed. This is based on the assumption that in-place op has
        # same device info. (new_op's inputs are not remapped yet here.)
        for dev, output in zip(output_dev, new_op.output):
            if output in blob_to_device and (
                output not in new_op.input and
                not device_equal(blob_to_device[output], dev)
            ):
                raise RuntimeError(
                    "In-place blob: {} is not supported between operators "
                    "with different device option previous:{} now: {}. "
                    "Failed op:\n {}".format(
                        output, blob_to_device[output], dev, new_op
                    )
                )

        remap_input(new_op, temp_remap)

        # keep inplace blobs inplace
        original_inputs = list(op.input)
        for i, out in enumerate(new_op.output):
            try:
                input_idx = original_inputs.index(out)
            except ValueError:
                continue
            new_op.output[i] = new_op.input[input_idx]

        blob_to_device.update(
            {o: d for d, o in zip(output_dev, new_op.output)})
        new_net.extend_ops([new_op])

    return new_net, blob_to_device


def InjectDeviceCopiesAmongNets(nets, blob_to_device_init=None):
    """
    Takes in a list of nets. They usually represent your whole execution graph.
    This function inserts cross-device copy ops into all nets and resolves
    inter-net external input dependencies: Copy ops are inserted where an
    external input of a net is produced on a different device than the one
    it is required on.

    Inputs:
      nets: a list of nets
      blob_to_device_init: optional initial blob -> DeviceOption map. It is
        copied first, so the caller's dict is never modified.
    Outputs:
      new_nets: a list of new nets with device difference solved.
      blob_to_device: the resulting blob -> DeviceOption map over all nets.

    Notes:
      1. You MUST pass nets in execution order. e.g. [train_init, train]
    """
    assert isinstance(nets, list), \
        "nets {} should be a list of nets.".format(str(nets))
    assert all(isinstance(net, Net) for net in nets), \
        "nets {} should be a list of nets.".format(str(nets))
    # A holistic blob to device mapping. Copy the initial map: it is updated
    # in place below. Previously a non-empty caller dict was mutated while an
    # empty one was not.
    blob_to_device = dict(blob_to_device_init) if blob_to_device_init else {}
    blob_remap = {}
    new_nets = []

    for net in nets:
        new_net, blob_to_device = InjectCrossDeviceCopies(
            net,
            blob_to_device=blob_to_device,
            blob_remap=blob_remap,
        )
        new_nets.append(new_net)

    return new_nets, blob_to_device


def InjectDeviceCopiesAmongNetsWithoutB2D(nets, blob_to_device_init=None):
    """Like InjectDeviceCopiesAmongNets, but return only the list of new nets."""
    result = InjectDeviceCopiesAmongNets(nets, blob_to_device_init)
    return result[0]


def get_net_name(netlike):
    """
    Return the name of a net-like object: the proto name of a Net, the name
    of a NetDef, or the argument itself unchanged for anything else
    (e.g. a name string).
    """
    if isinstance(netlike, Net):
        proto = netlike.Proto()
        return proto.name
    return netlike.name if isinstance(netlike, caffe2_pb2.NetDef) else netlike


def output_to_list(op_output):
    """
    Ensure that the output of an operator is a list.

    Useful for operators with a variable number of outputs, when a list is
    wanted even if there is only one output.

    Args:
        op_output: Either a BlobReference or a list/tuple of BlobReferences.

    Returns:
        A list of BlobReferences.
    """
    assert type(op_output) in (list, tuple, BlobReference)
    if isinstance(op_output, BlobReference):
        return [op_output]
    return list(op_output)


def _add_net_to_dict(net_dict, net):
    """
    Register `net` in `net_dict` under get_net_name(net).

    Net objects are stored as-is; anything else (a NetDef, a name) is stored
    as None. Returns True when the name was new. Returns False when it was
    already present, in which case the stored entry must be None or equal to
    `net`.
    """
    name = get_net_name(net)
    if name not in net_dict:
        net_dict[name] = net if isinstance(net, Net) else None
        return True
    existing = net_dict[name]
    assert existing is None or net == existing, (
        'Different nets with same name: ' + name)
    return False


class ExecutionStep:
    """
    Python-side builder for a caffe2_pb2.ExecutionStep proto.

    A step runs either a list of nets or a list of substeps, never both.
    Once a step has been added to a plan or to another step it is marked as
    used and may no longer be mutated. Its proto has already been copied into
    the parent, so later changes would not be seen there.
    """
    # Step names handed out by _get_next_step_name; shared by all steps.
    _step_names_used = set()

    @staticmethod
    def _get_next_step_name(basename):
        """
        Return `basename` if it has not been handed out yet, otherwise the
        first free `basename_<k>` for k = 1, 2, ...; the result is reserved.
        """
        name = basename
        next_idx = 1
        while name in ExecutionStep._step_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        ExecutionStep._step_names_used.add(name)
        return name

    def __init__(self, name, nets=None, num_iter=None):
        """
        Args:
            name: step name; when falsy, a unique 'step[_k]' name is generated.
            nets: optional single Net (subclasses included), or an iterable
                of nets / NetDefs / net names. Each name is listed once.
            num_iter: optional number of iterations.
        """
        self._step = caffe2_pb2.ExecutionStep()
        self._step.name = name or ExecutionStep._get_next_step_name('step')
        self._net_dict = OrderedDict()
        self._is_used = False
        self._substeps = []
        if nets is not None:
            # isinstance (not `type(...) is Net`) so Net subclasses are
            # wrapped too instead of being iterated.
            if isinstance(nets, Net):
                nets = [nets]
            for net in nets:
                if _add_net_to_dict(self._net_dict, net):
                    self._step.network.extend([get_net_name(net)])
        if num_iter is not None:
            self._step.num_iter = num_iter

    def get_net(self, name):
        """Return the Net registered under `name` (None if only a name/NetDef was given)."""
        return self._net_dict[name]

    def Name(self):
        return self._step.name

    def __str__(self):
        return self._step.name

    def _assert_can_mutate(self):
        assert not self._is_used, (
            'Cannot mutate a step that has already been added to a plan/step.')

    def _notify_is_used(self):
        self._is_used = True

    def Proto(self):
        return self._step

    def HasNets(self):
        """True if at least one network is listed in this step."""
        return len(self._step.network) > 0

    def HasSubsteps(self):
        """True if at least one substep is listed in this step."""
        return len(self._step.substep) > 0

    def Nets(self):
        return list(self._net_dict.values())

    def Substeps(self):
        return self._substeps

    def SetIter(self, num_iter):
        self._assert_can_mutate()
        self._step.num_iter = num_iter

    def SetCreateWorkspace(self, create_workspace):
        self._assert_can_mutate()
        self._step.create_workspace = create_workspace

    def SetNumConcurrentInstances(self, num_concurrent_instances):
        self._assert_can_mutate()
        self._step.num_concurrent_instances = num_concurrent_instances

    def SetOnlyOnce(self, only_once):
        self._assert_can_mutate()
        self._step.only_once = only_once

    def SetShouldStopBlob(self, should_stop_blob):
        assert isinstance(should_stop_blob, BlobReference), (
            "expects BlobReference here, got {}".format(type(should_stop_blob)))
        self._assert_can_mutate()
        self._step.should_stop_blob = str(should_stop_blob)

    def RunEveryMillis(self, interval):
        """
        Run this step every `interval` milliseconds, as long as its
        siblings are still running. It is guaranteed that, after all
        siblings finish, this step will run at least once.

        This property is ignored for top-level ExecutionSteps. Like every
        other setter, it may only be called before the step is used.
        """
        self._assert_can_mutate()
        self._step.run_every_ms = interval

    def SetReportNet(self, report_net, report_interval):
        """ DEPRECATED. Use RunEveryMillis instead. """
        self._assert_can_mutate()
        _add_net_to_dict(self._net_dict, report_net)
        self._step.report_net = get_net_name(report_net)
        self._step.report_interval = report_interval

    def AddSubstep(self, substep):
        """
        Append `substep` (an ExecutionStep or a raw ExecutionStep proto) and
        return self. A substep with neither nets nor substeps is marked as
        used but not added. Raises AssertionError if this step has nets.
        """
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        if isinstance(substep, ExecutionStep):
            substep._notify_is_used()
            if not substep.HasNets() and not substep.HasSubsteps():
                return self
            for net in substep.Nets():
                _add_net_to_dict(self._net_dict, net)
            self._substeps.append(substep)
            proto = substep.Proto()
        else:
            proto = substep
        self._step.substep.add().CopyFrom(proto)
        return self

    def SetConcurrentSubsteps(self, concurrent_substeps):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        self._step.concurrent_substeps = concurrent_substeps

    def AddNet(self, net):
        """Append `net` to this step and return self. Raises AssertionError if this step has substeps."""
        self._assert_can_mutate()
        assert not self.HasSubsteps(), 'Cannot have both network and substeps.'
        assert isinstance(net, Net)
        _add_net_to_dict(self._net_dict, net)
        self._step.network.extend([get_net_name(net)])
        return self

    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this execution step and its children.
        Entries registered only by name (stored as None) are skipped.
        """
        return [
            attr
            for net in self._net_dict.values()
            if net is not None
            for attr in net.get_attributes(name)
        ]

    @classmethod
    def create_from_proto(cls, step_proto, net_obj_dict, net_proto_dict):
        """
        Create ExecutionStep from ExecutionStep protobuf recursively.

        Nets are looked up in `net_obj_dict` by name, or built from
        `net_proto_dict` and cached into `net_obj_dict`.
        """
        assert isinstance(step_proto, caffe2_pb2.ExecutionStep)
        assert (len(step_proto.network) > 0 and len(step_proto.substep) == 0) or \
            (len(step_proto.network) == 0 and len(step_proto.substep) > 0)

        steps_or_nets = []
        if len(step_proto.substep) > 0:
            for substep_proto in step_proto.substep:
                steps_or_nets.append(ExecutionStep.create_from_proto(
                    substep_proto, net_obj_dict, net_proto_dict))
        else:
            for net_name in step_proto.network:
                if net_name not in net_obj_dict:
                    assert net_name in net_proto_dict
                    net = Net(net_proto_dict[net_name])
                    net_obj_dict[net_name] = net
                net = net_obj_dict[net_name]
                assert isinstance(net, Net)
                steps_or_nets.append(net)

        num_iter = step_proto.num_iter if step_proto.HasField('num_iter') else None
        concurrent_substeps = step_proto.concurrent_substeps if\
            step_proto.HasField('concurrent_substeps') else None
        should_stop_blob = BlobReference(step_proto.should_stop_blob) if\
            step_proto.HasField('should_stop_blob') else None
        only_once = step_proto.only_once if\
            step_proto.HasField('only_once') else None
        num_concurrent_instances = step_proto.num_concurrent_instances if\
            step_proto.HasField('num_concurrent_instances') else None
        create_workspace = step_proto.create_workspace if\
            step_proto.HasField('create_workspace') else None
        run_every_ms = step_proto.run_every_ms if\
            step_proto.HasField('run_every_ms') else None

        return execution_step(
            step_proto.name,
            steps_or_nets,
            num_iter=num_iter,
            report_net=None,        # DEPRECATED
            report_interval=None,   # DEPRECATED
            concurrent_substeps=concurrent_substeps,
            should_stop_blob=should_stop_blob,
            only_once=only_once,
            num_concurrent_instances=num_concurrent_instances,
            create_workspace=create_workspace,
            run_every_ms=run_every_ms)


def add_nets_in_order(step, net_list):
    """
    Append to `net_list`, in place and without duplicates, the names of the
    nets used by `step`. Substeps are visited first (depth-first), then the
    step's own networks, then its report net, if any.
    """
    proto = step.Proto()
    for child in step.Substeps():
        add_nets_in_order(child, net_list)
    candidates = list(proto.network)
    # FIXME(azzolini): This is actually wrong. Report nets should be
    # instantiated first since they may run before any substep is run.
    # However, currently, Reporter depends on this behavior.
    if proto.report_net:
        candidates.append(proto.report_net)
    for name in candidates:
        if name not in net_list:
            net_list.append(name)


class Plan:
    """Builder for a caffe2 ``PlanDef`` protobuf.

    A plan owns an ordered list of ExecutionStep objects (``_steps``) and a
    registry of the nets those steps use (``_net_dict``). The underlying
    proto returned by ``Proto()`` accumulates a copy of each added step's
    proto in ``execution_step`` and a copy of each newly registered net's
    proto in ``network``.
    """

    def __init__(self, name_or_step):
        """Create a plan.

        Args:
            name_or_step: a string used as the plan name, or an
                ExecutionStep whose name becomes the plan name and which is
                immediately added via ``AddStep``.

        Raises:
            ValueError: if ``name_or_step`` is neither a string nor an
                ExecutionStep.
        """
        self._plan = caffe2_pb2.PlanDef()
        self._net_dict = OrderedDict()  # registered nets, insertion ordered
        self._steps = []    # A list of ExecutionStep
        if isinstance(name_or_step, ExecutionStep):
            self._plan.name = name_or_step.Name()
            self.AddStep(name_or_step)
        elif isinstance(name_or_step, basestring):
            self._plan.name = name_or_step
        else:
            raise ValueError('name_or_step must be a string or ExecutionStep')

    def __str__(self):
        return self._plan.name

    def Proto(self):
        """Return the underlying ``caffe2_pb2.PlanDef`` (not a copy)."""
        return self._plan

    def AddNets(self, nets):
        """Register ``nets`` with this plan.

        Each net is passed to ``_add_net_to_dict``. Only when that call
        returns a truthy value (presumably: the net was not registered
        before -- confirm against ``_add_net_to_dict``) is the net asserted
        to be a ``Net`` and its proto appended to the plan's ``network``.
        """
        for net in nets:
            if _add_net_to_dict(self._net_dict, net):
                assert isinstance(net, Net)
                self._plan.network.add().CopyFrom(net.Proto())

    def Nets(self):
        """Return the registered nets as a new list, in registration order."""
        return list(self._net_dict.values())

    def AddStep(self, step):
        """Append ``step`` and register every net it uses, in usage order.

        The step is always marked as used via ``_notify_is_used()``, even
        when it is then skipped because it has neither nets nor substeps.
        """
        assert isinstance(step, ExecutionStep)
        step._notify_is_used()
        if not step.HasNets() and not step.HasSubsteps():
            return
        self._plan.execution_step.add().CopyFrom(step.Proto())
        self._steps.append(step)
        # nets need to be added to the plan in order of usage
        net_list = []
        add_nets_in_order(step, net_list)
        self.AddNets([step.get_net(n) for n in net_list])

    def Steps(self):
        """Return the internal list of added steps (not a copy)."""
        return self._steps

    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this plan.
        """
        return [
            attr
            for net in self._net_dict.values()
            for attr in net.get_attributes(name)
        ]

    @classmethod
    def create_from_proto(cls, plan_proto):
        """Rebuild a plan from a ``caffe2_pb2.PlanDef``.

        Returns an instance of ``cls``, so a subclass calling this
        alternate constructor gets its own type back.

        Raises:
            AssertionError: if ``plan_proto`` is not a ``PlanDef`` or it
                contains two nets with the same name.
        """
        assert isinstance(plan_proto, caffe2_pb2.PlanDef)
        plan = cls(plan_proto.name)
        # Copy plan-level fields, then clear the repeated fields that
        # AddStep/AddNets repopulate below so nothing is duplicated.
        plan._plan.CopyFrom(plan_proto)
        del plan._plan.network[:]
        del plan._plan.execution_step[:]

        # Net objects built while reconstructing steps, shared across steps
        # (filled in by ExecutionStep.create_from_proto -- presumably keyed
        # by net name; confirm).
        net_obj_dict = {}
        # Net protos from the plan keyed by name; duplicate names rejected.
        net_proto_dict = {}
        for net_proto in plan_proto.network:
            assert net_proto.name not in net_proto_dict
            net_proto_dict[net_proto.name] = net_proto

        for step_proto in plan_proto.execution_step:
            step = ExecutionStep.create_from_proto(
                step_proto, net_obj_dict, net_proto_dict)
            plan.AddStep(step)

        return plan
2950

2951

2952
def to_execution_step(step_or_nets, default_name=None):
    """Coerce ``step_or_nets`` into an ExecutionStep.

    - An ExecutionStep is returned as-is.
    - When ``default_name`` is falsy and the input has a ``name`` attribute,
      that attribute is used as the step name.
    - A NetBuilder is unwrapped: its ``_stop_blob`` becomes the step's
      should_stop_blob and the result of its ``get()`` becomes the contents.
    - Everything else is passed to execution_step() unchanged.
    """
    # Imported lazily (presumably to avoid an import cycle with
    # net_builder -- confirm).
    from caffe2.python.net_builder import NetBuilder
    if isinstance(step_or_nets, ExecutionStep):
        return step_or_nets

    if not default_name and hasattr(step_or_nets, 'name'):
        default_name = step_or_nets.name

    is_builder = isinstance(step_or_nets, NetBuilder)
    stop_blob = step_or_nets._stop_blob if is_builder else None
    contents = step_or_nets.get() if is_builder else step_or_nets
    return execution_step(
        default_name, contents, should_stop_blob=stop_blob)
2965

2966

2967
def execution_step(default_name,
                   steps_or_nets,
                   num_iter=None,
                   report_net=None,
                   report_interval=None,
                   concurrent_substeps=None,
                   should_stop_blob=None,
                   only_once=None,
                   num_concurrent_instances=None,
                   create_workspace=False,
                   run_every_ms=None):
    """
    Build and return an ExecutionStep named ``default_name``.

    - steps_or_nets can be:
      - None (or any other falsy value): no nets or substeps are added
      - Net
      - ExecutionStep
      - list<Net>: every net is added directly to the step
      - a list containing any non-Net: every element is converted with
        to_execution_step() and added as a substep
      Any other truthy value raises ValueError.
    - should_stop_blob is either None or a scalar boolean blob.
      - This blob is checked AFTER every substeps/subnets.
      - If specified and true, then this step will return immediately.
      - Be sure to handle race conditions if setting from concurrent threads.
    - should_stop_blob and num_iter are mutually exclusive (asserted); if
      neither is provided, num_iter defaults to 1.
    - report_net requires report_interval (asserted).
    - only_once, concurrent_substeps and num_concurrent_instances are applied
      only when not None; create_workspace and run_every_ms only when truthy.
    """
    assert not (should_stop_blob is not None and num_iter is not None), (
        'Cannot set both should_stop_blob and num_iter.')
    if num_iter is None and should_stop_blob is None:
        num_iter = 1

    step = ExecutionStep(default_name)
    # Settings applied only when explicitly provided, in this fixed order.
    for value, apply_setting in (
            (should_stop_blob, step.SetShouldStopBlob),
            (num_iter, step.SetIter),
            (only_once, step.SetOnlyOnce),
            (concurrent_substeps, step.SetConcurrentSubsteps)):
        if value is not None:
            apply_setting(value)
    if report_net is not None:
        assert report_interval is not None
        step.SetReportNet(report_net, report_interval)
    if num_concurrent_instances is not None:
        step.SetNumConcurrentInstances(num_concurrent_instances)
    # These two are gated on truthiness: False / 0 / None are no-ops.
    if create_workspace:
        step.SetCreateWorkspace(True)
    if run_every_ms:
        step.RunEveryMillis(run_every_ms)

    if isinstance(steps_or_nets, ExecutionStep):
        step.AddSubstep(steps_or_nets)
    elif isinstance(steps_or_nets, Net):
        step.AddNet(steps_or_nets)
    elif isinstance(steps_or_nets, list):
        only_nets = all(isinstance(item, Net) for item in steps_or_nets)
        for item in steps_or_nets:
            if only_nets:
                step.AddNet(item)
            else:
                step.AddSubstep(to_execution_step(item))
    elif steps_or_nets:
        raise ValueError(
            'steps_or_nets must be a step, a net, or a list of nets or steps.')
    return step
3031

3032

3033
def scoped_execution_step(name, *args, **kwargs):
    """Like execution_step(), but a truthy ``name`` is wrapped in ScopedName.

    A falsy ``name`` is forwarded unchanged.
    """
    if name:
        name = ScopedName(name)
    return execution_step(name, *args, **kwargs)
3037

3038

3039
def _extract_stacktrace():
3040
    '''
3041
    This function extracts stacktrace without file system access
3042
    by purely using sys._getframe() and removes part that belongs to
3043
    this file (core.py). We are not using inspect module because
3044
    its just a wrapper on top of sys._getframe() whose
3045
    logic is based on accessing source files on disk - exactly what
3046
    we are trying to avoid here. Same stands for traceback module
3047

3048
    The reason for file system access avoidance is that
3049
    if code is located on an NFS, file access might be slow
3050

3051
    Function returns a list of tuples (file_name, line_number, function)
3052
    '''
3053

3054
    result = []
3055
    # Ignore top 3 layers of stack: this function, _CreateAndAddToSelf, and
3056
    # whatever calls _CreateAndAddToSelf (either __getattr__ or Python)
3057
    frame = sys._getframe(3)
3058
    # We just go down the frame stack in a loop
3059
    while frame:
3060
        # Its important to extract information from the frame here
3061
        # as frame's current line most probably will change later.
3062
        result.append((frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name))
3063
        frame = frame.f_back
3064
    return result
3065

3066

3067
# Engine-preference setters re-exported from the C extension module so they
# are reachable as core.SetPerOpEnginePref, core.SetGlobalEnginePref, etc.
(SetPerOpEnginePref,
 SetGlobalEnginePref,
 SetEnginePref,
 SetOpEnginePref) = (C.set_per_op_engine_pref,
                     C.set_global_engine_pref,
                     C.set_engine_pref,
                     C.set_op_engine_pref)
3071

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.