import copy
import os
import sys
import traceback

import numpy as np

from collections import namedtuple, OrderedDict, defaultdict
from past.builtins import basestring
from itertools import chain
from typing import Dict, Set

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
from caffe2.python.lazy import TriggerLazyImport
from caffe2.python.control_ops_grad import \
    gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output

import caffe2.python._import_c_extension as C
if sys.platform == 'darwin' and 'leveldb' in C.registered_dbs():
    print('If you are using homebrew leveldb on a Mac OS, you might see an '
          'error warning you that malloc_zone_unregister() failed. This is '
          'not a caffe2 issue but is due to the homebrew leveldb having an '
          'incompatible memory allocator. It does not affect usage.')

DeviceScope = scope.DeviceScope
NameScope = scope.NameScope
# Python copy of the caffe2_pb2.TensorProto.DataType enum. Most members were
# elided in this excerpt; they are reconstructed here to mirror caffe2.proto,
# and the loop below verifies that they stay in sync.
class DataType:
    UNDEFINED, FLOAT, INT32, BYTE, STRING, BOOL = 0, 1, 2, 3, 4, 5
    UINT8, INT8, UINT16, INT16, INT64 = 6, 7, 8, 9, 10
    FLOAT16, DOUBLE = 12, 13
    ZERO_COLLISION_HASH = 14
    REBATCHING_BUFFER = 15


# Verify that the DataType values above agree with the protobuf definition.
for name, value in caffe2_pb2.TensorProto.DataType.items():
    py_value = getattr(DataType, name, None)
    if py_value != value:
        raise AssertionError(
            f"DataType {name} does not match the value defined in "
            f"caffe2.proto: {py_value} vs {value}"
        )
def _GetRegisteredOperators():
    return set(workspace.RegisteredOperators())


_REGISTERED_OPERATORS = _GetRegisteredOperators()


def RefreshRegisteredOperators(trigger_lazy=True):
    if trigger_lazy:
        TriggerLazyImport()

    global _REGISTERED_OPERATORS
    _REGISTERED_OPERATORS = _GetRegisteredOperators()
_GLOBAL_INIT_ARGS = []


def GlobalInit(args):
    TriggerLazyImport()
    # Record the init args before forwarding them to the C extension.
    _GLOBAL_INIT_ARGS.extend(args[1:])
    C.global_init(args)


def GetGlobalInitArgs():
    return _GLOBAL_INIT_ARGS[:]


def IsOperator(op_type):
    return IsOperatorWithEngine(op_type, engine='DEFAULT')


def IsOperatorWithEngine(op_type, engine):
    TriggerLazyImport()
    return C.op_registry_key(op_type, engine) in _REGISTERED_OPERATORS


def IsGPUDeviceType(device_type):
    return device_type in {caffe2_pb2.CUDA, caffe2_pb2.HIP}
def DeviceOption(device_type, device_id=0, random_seed=None, node_name=None,
                 numa_node_id=None, extra_info=None):
    option = caffe2_pb2.DeviceOption()
    option.device_type = device_type
    option.device_id = device_id
    if node_name is not None:
        option.node_name = node_name
    if random_seed is not None:
        option.random_seed = random_seed
    if numa_node_id is not None:
        assert device_type == caffe2_pb2.CPU
        option.numa_node_id = numa_node_id
    if extra_info is not None:
        option.extra_info.extend(extra_info)
    return option
def device_option_equal(opt1, opt2, ignore_node_name=True, ignore_random_seed=True):
    if not opt1 or not opt2:
        return opt1 == opt2
    if not ignore_node_name and opt1.node_name != opt2.node_name:
        return False
    if not ignore_random_seed and opt1.random_seed != opt2.random_seed:
        return False
    if not opt1.device_type or not opt2.device_type:
        # At least one option is for CPU; they are equal only if both are.
        return not opt1.device_type and not opt2.device_type
    return opt1.device_id == opt2.device_id
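
# Illustrative sketch (commented out so importing this module stays
# side-effect free): two options that differ only in node_name compare equal
# by default, since ignore_node_name=True.
#
#   opt_a = DeviceOption(caffe2_pb2.CPU, node_name='worker0')
#   opt_b = DeviceOption(caffe2_pb2.CPU, node_name='worker1')
#   assert device_option_equal(opt_a, opt_b)
#   assert not device_option_equal(opt_a, opt_b, ignore_node_name=False)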
def InferBlobDevices(net):
    """
    Compute mapping from blobs to devices by looking at the device option
    of the op that creates each blob.
    """
    mapping = {}
    for op in net.Proto().op:
        op_device = op.device_option
        if op_device is None:
            op_device = caffe2_pb2.DeviceOption(caffe2_pb2.CPU)
        for b in op.output:
            mapping[b] = op_device
    return mapping
def InferOpBlobDevicesAsDict(op):
    input_dev_list, output_dev_list = InferOpBlobDevices(op)
    input_dict = {
        op.input[i]: input_dev_list[i]
        for i in range(len(op.input))
    }
    output_dict = {
        op.output[i]: output_dev_list[i]
        for i in range(len(op.output))
    }
    return input_dict, output_dict
def InferOpBlobDevices(op):
    device_info = C.infer_op_input_output_device(op.SerializeToString())
    input_info = []
    output_info = []
    for dev_str in device_info[0]:
        device_option = caffe2_pb2.DeviceOption()
        device_option.ParseFromString(dev_str)
        input_info.append(device_option)
    for dev_str in device_info[1]:
        device_option = caffe2_pb2.DeviceOption()
        device_option.ParseFromString(dev_str)
        output_info.append(device_option)
    return input_info, output_info


def InferOpDeviceAsBlobDevices(op):
    op_dev = op.device_option if op.device_option else caffe2_pb2.DeviceOption()
    input_dev = [op_dev] * len(op.input)
    output_dev = [op_dev] * len(op.output)
    return input_dev, output_dev
GradientSlice = namedtuple('GradientSlice', ['indices', 'values'])


class BlobReference(object):
    """A wrapper around a blob in a net.

    BlobReference gives us a way to refer to the network that the blob is
    generated from. Note that blobs are, essentially, just strings in the
    current workspace.
    """

    def __init__(self, name, net=None):
        """Initializes a blob reference.

        Note that this does not prepend the namescope. If needed, use
        ScopedBlobReference() to prepend the existing namespace.
        """
        if isinstance(name, str):
            self._name = name
        elif isinstance(name, bytes):
            self._name = name.decode('utf-8')
        else:
            self._name = str(name)
        self._from_net = net

    def __hash__(self):
        return hash(self._name)
    def __eq__(self, other):
        if isinstance(other, str):
            return self._name == other
        elif isinstance(other, bytes):
            return self._name == other.decode('utf-8')
        elif isinstance(other, BlobReference):
            return self._name == other._name
        else:
            return False

    def __ne__(self, other):
        return not (self == other)

    def __str__(self):
        return self._name

    def __repr__(self):
        return 'BlobReference("{}")'.format(self._name)
    def __add__(self, other):
        if not isinstance(other, str):
            raise RuntimeError('Cannot add BlobReference to a non-string.')
        return BlobReference(self._name + other, self._from_net)

    def __radd__(self, other):
        if not isinstance(other, str):
            raise RuntimeError('Cannot add a non-string to BlobReference.')
        return BlobReference(other + self._name, self._from_net)

    def Net(self):
        return self._from_net

    def GetNameScope(self):
        return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]

    def GetUnscopedName(self):
        return self._name[self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]
    def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
        """Internal function that routes the operator generation to the
        network's __getattr__ function.
        """
        inputs = [] if inputs is None else inputs
        if isinstance(inputs, BlobReference) or isinstance(inputs, str):
            inputs = [inputs]
        # Add self to the input list.
        inputs.insert(0, self)
        return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)
    def __getattr__(self, op_type):
        """A wrapper allowing one to initiate operators from a blob reference.

        Example: for a blob reference b that comes from network n, doing
            b.Relu(...)
        is equivalent to doing
            net.Relu([b], ...)
        """
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if self._from_net is None:
            raise AttributeError(
                'You cannot use a blob reference that does not have a net '
                'source to create operators. Create the operator from an '
                'explicit net object.')
        if not IsOperator(op_type):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        return lambda *args, **kwargs: self._CreateAndAddToNet(
            op_type, *args, **kwargs)
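
    # Illustrative sketch (commented out; assumes a registered "Relu"
    # operator and a Net instance as defined later in this module):
    #
    #   net = Net("example")
    #   x = net.AddExternalInput("x")
    #   y = x.Relu([], "y")    # same as net.Relu([x], "y")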
    def __dir__(self):
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op or '_ENGINE_CUDNN' in op]
        return sorted(set(chain(
            dir(type(self)),
            self.__dict__.keys(),
            additional_methods
        )))


def ScopedName(name):
    """Prefix the name with the current scope."""
    if isinstance(name, bytes):
        name = name.decode('ascii')
    return scope.CurrentNameScope() + name
def ScopedBlobReference(name, *args, **kwargs):
    """Returns a blob reference with scope prefixed."""
    return BlobReference(ScopedName(name), *args, **kwargs)
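
# Illustrative sketch (commented out): with a name scope active, ScopedName
# and ScopedBlobReference prepend the scope prefix to the raw name.
#
#   with NameScope("foo"):
#       ref = ScopedBlobReference("bar")   # refers to blob "foo/bar"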
def _RectifyInputOutput(blobs, net=None):
    """A helper function to rectify the input or output of the CreateOperator
    interface.
    """
    if isinstance(blobs, (bytes, str)):
        # If blobs is a single string, prepend scope.CurrentNameScope().
        return [ScopedBlobReference(blobs, net=net)]
    elif type(blobs) is BlobReference:
        # If blobs is a BlobReference, simply put it into a list.
        return [blobs]
    elif type(blobs) in (list, tuple):
        # If blobs is a list or tuple, go through it and type check.
        rectified = []
        for blob in blobs:
            if isinstance(blob, (bytes, str)):
                rectified.append(ScopedBlobReference(blob, net=net))
            elif type(blob) is BlobReference:
                rectified.append(blob)
            else:
                raise TypeError(
                    "I/O blob #{} of unsupported type: {} of type {}"
                    .format(len(rectified), str(blob), type(blob)))
        return rectified
    else:
        raise TypeError(
            "Unknown input/output type: %s of type %s." %
            (str(blobs), type(blobs))
        )
def CreateOperator(
    operator_type,
    inputs,
    outputs,
    name='',
    control_input=None,
    device_option=None,
    arg=None,
    engine=None,
    debug_info=None,
    **kwargs
):
    """A function wrapper that allows one to create operators based on the
    operator type. The type should be a string corresponding to an operator
    registered with Caffe2.
    """
    operator = caffe2_pb2.OperatorDef()
    if os.environ.get('CAFFE2_DEBUG'):
        stack = traceback.format_stack()
        operator.debug_info = "".join(stack[:-1])

    operator.type = operator_type
    operator.name = name
    # Add rectified inputs and outputs.
    inputs = _RectifyInputOutput(inputs)
    outputs = _RectifyInputOutput(outputs)
    operator.input.extend(map(str, inputs))
    operator.output.extend(map(str, outputs))
    if control_input:
        control_input = _RectifyInputOutput(control_input)
        operator.control_input.extend(map(str, control_input))
    # Set the device option.
    if device_option is not None:
        operator.device_option.CopyFrom(device_option)
    elif scope.CurrentDeviceScope() is not None:
        operator.device_option.CopyFrom(scope.CurrentDeviceScope())
    if engine is not None:
        operator.engine = engine
    if debug_info is not None:
        operator.debug_info = debug_info
    # random_seed is defined in the device option, so it needs special care.
    if 'random_seed' in kwargs:
        operator.device_option.random_seed = kwargs['random_seed']
        del kwargs['random_seed']
    # Add given arguments that do not need parsing.
    if arg is not None:
        operator.arg.extend(arg)
    # Add all other arguments.
    for key, value in kwargs.items():
        if value is not None:
            operator.arg.add().CopyFrom(utils.MakeArgument(key, value))

    if workspace.IsImmediate():
        workspace.RunOperatorImmediate(operator)
    return operator
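
# Illustrative sketch (commented out; assumes the "Relu" operator is
# registered and blob "X" has been fed into the workspace):
#
#   op = CreateOperator("Relu", ["X"], ["Y"])
#   workspace.RunOperatorOnce(op)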
def _RegisterPythonImpl(
    f, grad_f=None, python_func_type=None, pass_workspace=False
):
    if python_func_type:
        func = python_func_type(f)
        f = func.forward
        grad_f = func.backward
    else:
        if isinstance(f, tuple):
            f = f[0](*f[1], **f[2])
        if isinstance(grad_f, tuple):
            grad_f = grad_f[0](*grad_f[1], **grad_f[2])

    token = C.register_python_op(f, pass_workspace, '')
    if grad_f:
        C.register_python_gradient_op(token, grad_f)
    return token
def CreatePythonOperator(
    f, inputs,
    outputs,
    *args,
    grad_f=None,
    pass_workspace=False,
    python_func_type=None,
    **kwargs
):
    """
    `f` should have a signature (inputs, outputs)

    If `pass_workspace` is True, the signature is changed to
    (inputs, outputs, workspace) where `workspace` is the workspace the op
    is going to run on. This is potentially dangerous (as the op can manipulate
    the workspace directly), use at your own risk.
    """
    kwargs["token"] = _RegisterPythonImpl(
        f, grad_f, python_func_type, pass_workspace=pass_workspace
    )
    return CreateOperator("Python", inputs, outputs, *args, **kwargs)
def GetIndexFromGradientList(g_list, name):
    """A helper function to get the index from a gradient list, None if not
    matching."""
    for i, g in enumerate(g_list):
        if g == name:
            return i
        elif type(g) is GradientSlice:
            if (g.indices == name or g.values == name):
                return i
    return None
OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions'])
GradGenMeta = namedtuple('GradGenMeta',
                         ['grad_op', 'idx', 'gradient', 'device_option'])
SparseGradGenMeta = namedtuple('SparseGradGenMeta', [
    'grad_op_indices', 'idx_indices',
    'grad_op_values', 'idx_values',
    'gradient', 'device_option',
])


class IR(object):
    """A simple IR class to keep track of all intermediate representations used
    in the gradient computation.
    """
    def __init__(self, operators):
        # ssa is a list of OpSSA tuples, one per operator played into the IR.
        self.ssa = []
        self.input_usages = defaultdict(lambda: defaultdict(list))
        self.frontier = defaultdict(int)
        self.gradient_frontier = {}
        self.gradient_generators = defaultdict(lambda: defaultdict(list))
        self.out_version_history = defaultdict(list)
        self.in_version_history = defaultdict(list)

        for op in operators:
            self.Play(op)

        self.SanityCheck(operators)
    def SanityCheck(self, operators):
        # Validate StopGradient usage by checking that its output is actually
        # consumed somewhere downstream.
        for op in operators:
            if op.type == 'StopGradient':
                if op.output[0] not in self.input_usages:
                    raise ValueError("""StopGradient's output '{}' is orphan.
You typically want to specify same input and output for
StopGradient. Op:\n\n{}""".format(op.output[0], str(op)))
    def Play(self, op):
        """Adds an op to the current IR, and updates the internal states to
        reflect the blobs and versions after the execution of the op.
        """
        # For inputs, use the current version in the dict.
        in_versions = {}
        for s in op.input:
            in_versions[s] = self.frontier[s]
            self.input_usages[s][self.frontier[s]].append(len(self.ssa))
            self.in_version_history[s].append((op, self.frontier[s]))
        # For outputs, use the current version plus one. If this is a
        # newly created blob, its version starts at zero.
        out_versions = {}
        for s in op.output:
            if s in self.frontier:
                self.frontier[s] += 1
            out_versions[s] = self.frontier[s]
            self.out_version_history[s].append((op, self.frontier[s]))
        # Add to SSA for bookkeeping.
        self.ssa.append(OpSSA(op, in_versions, out_versions))
    def CheckGradientOperatorInput(
            self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs):
        """Checks if the gradient operators can be correctly carried out."""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        original_index = GetIndexFromGradientList(g_output, grad_op_input)

        # Helpers that build debug messages for version mismatches.
        def versionMismatchInfoOut(name):
            s = "DEBUG HELP:\n"
            s += "Maybe you use same output blob twice for different ops?\n"
            s += "== Version history of blob [{}]\n".format(name)
            for (op, vers) in self.out_version_history[name]:
                s += "Version (out) {} <-- {}".format(vers, op)
                s += "\n"
            return s

        def versionMismatchInfoIn(name):
            s = "DEBUG HELP:\n"
            s += "Maybe the blob was overwritten by another op?\n"
            s += "== Version history of blob [{}]\n".format(name)
            for (op, vers) in self.in_version_history[name]:
                s += "version (in) {} <-- {}".format(vers, op)
                s += "\n"
            return s
        # If it is a dense or sparse gradient name, it should match the
        # version of the corresponding output.
        if original_index is not None:
            original_name = forward_op.output[original_index]
            if (out_versions[original_name] !=
                    self.gradient_frontier[original_name]):
                raise RuntimeError(
                    'Gradient name "%s" is expected to correspond '
                    'to version %d of "%s", but currently we have '
                    'version %d.\n\n' % (
                        grad_op_input, out_versions[original_name],
                        original_name,
                        self.gradient_frontier[original_name]) +
                    versionMismatchInfoOut(original_name))
        # If it is an output name, the current version should match the
        # version when the operator was run.
        elif grad_op_input in out_versions:
            if self.frontier[grad_op_input] != out_versions[grad_op_input]:
                raise RuntimeError(
                    'Gradient operator needs output "%s" at version'
                    ' %d, but currently we have version %d.\n\n' % (
                        grad_op_input, out_versions[grad_op_input],
                        self.frontier[grad_op_input]
                    ) + versionMismatchInfoOut(grad_op_input)
                )
        # If it is an input name, the current version should match the
        # version when the operator was run.
        elif grad_op_input in in_versions:
            if (self.frontier[grad_op_input] != in_versions[grad_op_input]):
                raise RuntimeError(
                    'Gradient operator needs input "%s" at version '
                    '%d, but currently we have version %d.\n\n' % (
                        grad_op_input, in_versions[grad_op_input],
                        self.frontier[grad_op_input]
                    ) + versionMismatchInfoIn(grad_op_input)
                )
        # If it is none of the above, it should be a blob that is
        # generated locally by one of the previous gradient operators.
        else:
            if grad_op_input not in locally_generated_blobs:
                raise RuntimeError(
                    'Blob name "%s" not in the scope of operator: '
                    '%s\nand is not generated by any of the local '
                    'gradient operators.' % (grad_op_input, str(forward_op))
                )
    def AppendSparseGenerators(self, sparse_generators):
        # Merge indices and values generators for sparse gradients.
        for name, input_generators in sparse_generators.items():
            for version, generators in input_generators.items():
                if len(generators) == 1:
                    # either indices or values are generated (but not both)
                    generator = generators[0]
                else:
                    # both indices and values are generated
                    assert(len(generators) == 2)
                    op1_i, idx1_i, op1_v, idx1_v, g1, dev_1 = generators[0]
                    op2_i, idx2_i, op2_v, idx2_v, g2, dev_2 = generators[1]
                    assert(g1 == g2)
                    assert dev_1 == dev_2, (
                        "Unequal devices for sparse generators: "
                        "{} and {}".format(dev_1, dev_2)
                    )
                    assert(op1_i is None or op2_i is None)
                    assert(op1_v is None or op2_v is None)
                    assert(idx1_i == 0 or idx2_i == 0)
                    assert(idx1_v == 0 or idx2_v == 0)
                    generator = SparseGradGenMeta(
                        op1_i or op2_i, idx1_i + idx2_i,
                        op1_v or op2_v, idx1_v + idx2_v,
                        g1, dev_1)
                self.gradient_generators[name][version].append(generator)
    def BuildGradientGenerators(
            self, fwd_op_idx, gradient_ops, g_output, g_input):
        """Updates gradient_generators and gradient_frontier"""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        locally_generated_blobs = []
        sparse_generators = defaultdict(lambda: defaultdict(list))

        for grad_op in gradient_ops:
            # (1) check that the inputs are valid
            for s in grad_op.input:
                self.CheckGradientOperatorInput(
                    s, g_output, fwd_op_idx, locally_generated_blobs)

            # (2) add outputs to the locally generated blobs; if an output
            # corresponds to the gradient of an input, also record it in
            # gradient_generators
            locally_generated_blobs.extend(map(str, grad_op.output))
            for i, output in enumerate(grad_op.output):
                input_index = GetIndexFromGradientList(g_input, output)
                if input_index is not None:
                    input_name = forward_op.input[input_index]
                    input_version = in_versions[input_name]
                    g = g_input[input_index]
                    if type(g) is GradientSlice:
                        # the output corresponds either to the indices or the
                        # values of the sparse gradient; create a (partial)
                        # SparseGradGenMeta accordingly
                        if g.indices == output:
                            m = SparseGradGenMeta(
                                grad_op, i, None, 0, g, grad_op.device_option)
                        else:
                            assert(g.values == output)
                            m = SparseGradGenMeta(
                                None, 0, grad_op, i, g, grad_op.device_option)
                        sparse_generators[input_name][input_version].append(m)
                    else:
                        self.gradient_generators[input_name][input_version] \
                            .append(GradGenMeta(
                                grad_op, i, g, grad_op.device_option))

        # (3) merge the indices and values generators for sparse gradients
        self.AppendSparseGenerators(sparse_generators)
        # (4) for ops (e.g., Add, Sum, Sub) which have gradient outputs
        # directly passed from inputs (not computed from gradient ops), create
        # a GradGenMeta with None grad_op and idx so that gradient_generators
        # knows such inputs are the ones we are looking for.
        for input_index, g in enumerate(g_input):
            input_name = forward_op.input[input_index]
            input_version = in_versions[input_name]
            if g is None:
                continue
            if type(g) is GradientSlice:
                if str(g.indices) not in locally_generated_blobs and \
                        str(g.values) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version].append(
                        SparseGradGenMeta(None, 0, None, 0, g, forward_op.device_option))
            else:
                if str(g) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version].append(
                        GradGenMeta(None, 0, g, forward_op.device_option))

        # Finally, for the gradients specified in g_input, update the gradient
        # frontier to reflect the input versions the gradients correspond to.
        for i, g in enumerate(g_input):
            if g is not None:
                input_name = forward_op.input[i]
                input_version = in_versions[input_name]
                self.gradient_frontier[input_name] = input_version
    def _GetSumOpOutputName(self, generator, input_name):
        def remove_suffix(s, suffix):
            if s.endswith(suffix):
                return s[:-len(suffix)]
            return s

        for g in generator:
            if type(g) is GradGenMeta:
                grad_op, idx, _, _ = g
                if grad_op:
                    return grad_op.output[idx]
            else:
                assert(type(g) is SparseGradGenMeta)
                op_i, idx_i, op_v, idx_v, _, _ = g
                if op_i:
                    return remove_suffix(op_i.output[idx_i], '_indices')
                if op_v:
                    return remove_suffix(op_v.output[idx_v], '_values')

        return input_name + '_grad'
    IS_AUTO_GEN_SUM_OPS_TAG = "is_auto_gen_sum_ops"
    ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG = "only_keep_is_auto_gen_sum_ops_tag"

    def _SetSumOpsDeviceOption(self, sum_ops, generators):
        only_keep_is_auto_gen_sum_ops_tag = False
        for generator in generators:
            # device options were already checked to be consistent, so we can
            # break after inspecting the first generator
            for extra_info in generator.device_option.extra_info:
                if extra_info == "{}:1".format(IR.ONLY_KEEP_IS_AUTO_GEN_SUM_OPS_TAG):
                    only_keep_is_auto_gen_sum_ops_tag = True
                    break
            break

        if only_keep_is_auto_gen_sum_ops_tag:
            # the generator's device_option asks us to only keep the
            # is_auto_gen_sum_ops tag and not copy the device_option itself
            for op in sum_ops:
                op.device_option.extra_info.extend([
                    "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG)
                ])
        else:
            for generator in generators:
                for op in sum_ops:
                    op.device_option.CopyFrom(generator.device_option)
                    op.device_option.extra_info.extend([
                        "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG)
                    ])
                break
    def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
        new_grad_output = (
            '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
        if grad_op.type == "If":
            disambiguate_grad_if_op_output(grad_op, idx, new_grad_output)
        else:
            grad_op.output[idx] = new_grad_output
        return grad_op.output[idx], cnt + 1

    def _CheckSumOpsConflict(self, out_base_name, g):
        if str(out_base_name) == str(g):
            raise RuntimeError(
                'The gradient output of empty gradient op can not '
                'be the same as the normal name of the current '
                'input gradient.')
    def _MakeDenseSumOps(self, generators, out_base_name):
        sum_op_input = []
        cnt = 0

        assert len(generators) > 1

        first_grad_op = True
        for generator in generators:
            grad_op, idx, g, _ = generator
            assert(type(g) is not GradientSlice)
            if grad_op:
                if first_grad_op:
                    first_grad_op = False
                    out = grad_op.output[idx]
                else:
                    out, cnt = self._DisambiguateGradOpOutput(grad_op, idx, cnt)
                sum_op_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g)
                sum_op_input.append(str(g))

        if out_base_name in sum_op_input:
            # Sum inplace mode works only for the first input, so do a swap.
            idx = sum_op_input.index(out_base_name)
            sum_op_input[0], sum_op_input[idx] = (
                sum_op_input[idx], sum_op_input[0]
            )
        sum_ops = [CreateOperator(
            "Sum",
            [BlobReference(x) for x in sum_op_input],
            BlobReference(out_base_name))]
        return sum_ops, out_base_name
    def _MakeSparseSumOps(self, generators, out_base_name):
        indices_concat_input = []
        values_concat_input = []
        cnt_i = 0
        cnt_v = 0

        for generator in generators:
            assert(type(generator) is SparseGradGenMeta)
            op_i, idx_i, op_v, idx_v, g, _ = generator
            if op_i:
                out, cnt_i = self._DisambiguateGradOpOutput(op_i, idx_i, cnt_i)
                indices_concat_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g.indices)
                indices_concat_input.append(g.indices)
            if op_v:
                out, cnt_v = self._DisambiguateGradOpOutput(op_v, idx_v, cnt_v)
                values_concat_input.append(out)
            else:
                self._CheckSumOpsConflict(out_base_name, g.values)
                values_concat_input.append(g.values)

        indices_concat_output = out_base_name + '_indices_concat'
        indices_concat_split = out_base_name + '_indices_concat_split'
        values_concat_output = out_base_name + '_values_concat'
        values_concat_split = out_base_name + '_values_concat_split'
        # Sum the given sparse representations by simply concatenating the
        # indices (resulting in the new indices) and concatenating the values
        # (resulting in the new values).
        sum_ops = [
            CreateOperator(
                "Concat",
                [BlobReference(x) for x in indices_concat_input],
                [BlobReference(x) for x in
                    [indices_concat_output, indices_concat_split]],
                axis=0
            ),
            CreateOperator(
                "Concat",
                [BlobReference(x) for x in values_concat_input],
                [BlobReference(x) for x in
                    [values_concat_output, values_concat_split]],
                axis=0
            ),
        ]
        sum_op_output = GradientSlice(
            indices=indices_concat_output,
            values=values_concat_output,
        )
        return sum_ops, sum_op_output
    def _MakeSumOps(self, input_name, input_version):
        generators = self.gradient_generators[input_name][input_version]
        out_base_name = self._GetSumOpOutputName(generators, input_name)
        types = list(set(type(x) for x in generators))
        assert(len(types) == 1)
        if types[0] is GradGenMeta:
            sum_ops, g = self._MakeDenseSumOps(generators, out_base_name)
        else:
            assert(types[0] is SparseGradGenMeta)
            sum_ops, g = self._MakeSparseSumOps(generators, out_base_name)
        self._SetSumOpsDeviceOption(sum_ops, generators)
        return sum_ops, g
    def _VerifyGradientGenerators(self, generator):
        # (1) check if all gradients are of the same type; aggregating a mix
        # of sparse and dense gradients is not supported yet
        if len({type(g) for g in generator}) > 1:
            raise RuntimeError(
                'Automatic aggregation of a mix of sparse and dense gradients '
                'is not supported yet')

        # If, for all the operators that used the blob, none or only one
        # produced the gradient, then no additional sum needs to be carried
        # out.
        if len(generator) < 2:
            return False

        all_gradient_names = []
        all_device_options = []
        for g in generator:
            if g.device_option:
                all_device_options.append(g.device_option)
            if type(g) is GradGenMeta:
                if g.grad_op:
                    all_gradient_names.append(g.gradient)
            else:
                assert(type(g) is SparseGradGenMeta)
                if g.gradient.values:
                    all_gradient_names.append(g.gradient.values)

        # Check if all grad op device options are the same.
        if len(all_device_options) >= 2 and not all(
                device_option_equal(d, all_device_options[0])
                for d in all_device_options[1:]):
            raise RuntimeError('Unexpected behavior: not all grad ops '
                               'have the same device option.')
        return True
    def DoGradientAccumulation(self, fwd_op_idx):
        """For each input name in the forward op, check if we will need to
        add gradient accumulation. If so, do gradient accumulation and return
        the list of gradient operators.

        The criteria for doing gradient accumulation is:
        (1) the specific input version has been used by multiple operators.
        (2) the current fwd_op_idx is the first to use that input, i.e. in the
            backward pass, is the last to optionally generate the gradient for
            the op.
        (3) For the operators that used the input, their gradient operators
            have generated more than 1 gradient.

        When accumulating operators, our current solution is to rename all the
        created gradients with an internal intermediate name, and then add a
        Sum() operator that adds up all the gradients. This may use more memory
        due to intermediate storage, but is usually the fastest approach as one
        can do one single sum for multiple intermediate gradients.
        """
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        additional_sum_ops = []
        grad_map = {}
        for _i, input_name in enumerate(set(forward_op.input)):
            input_version = in_versions[input_name]
            input_usage = self.input_usages[input_name][input_version]
            if (len(input_usage) <= 1 or fwd_op_idx != input_usage[0]):
                # We do not need to do gradient accumulation yet.
                continue
            generator = self.gradient_generators[input_name][input_version]
            try:
                if not self._VerifyGradientGenerators(generator):
                    continue
            except RuntimeError as err:
                raise RuntimeError(
                    "Gradients for param ''{}'' failed to verify: {}".format(
                        input_name,
                        err
                    )
                )

            # Finally, create the sum operator.
            sum_ops, g = self._MakeSumOps(input_name, input_version)
            additional_sum_ops.extend(sum_ops)
            grad_map[input_name] = g
        return additional_sum_ops, grad_map
    def _AppendAutoGradGenerator(self, y, grad, autograd_op):
        # The gradient here is not sparse as it was generated by a
        # ConstantFill operator. Autogeneration for sparse gradients is
        # not supported.
        generator = GradGenMeta(
            autograd_op, 0 if autograd_op else None, str(grad),
            autograd_op.device_option)

        self.gradient_generators[str(y)][self.frontier[str(y)]].append(
            generator)

    AUTOGEN_GRAD_SUFFIX = "_autogen_grad"

    def _GetInitGradients(self, ys):
        input_to_grad = {}
        gradient_ops = []

        for y, g in ys.items():
            autograd_op = None
            if g is None:
                autograd_op = CreateOperator(
                    "ConstantFill", [y], [str(y) + IR.AUTOGEN_GRAD_SUFFIX],
                    value=1.0)
                gradient_ops.append(autograd_op)
                g = autograd_op.output[0]
            # Since the C++ gradient registry does not have a notion of
            # NameScopes, convert all references to strings.
            input_to_grad[str(y)] = (
                GradientSlice(str(g[0]), str(g[1]))
                if isinstance(g, GradientSlice) else str(g))
            # Autogenerated gradients are assumed to be provided for the last
            # input version.
            if autograd_op is not None:
                self._AppendAutoGradGenerator(y, g, autograd_op)

        return input_to_grad, gradient_ops
    def _GenerateGradientsForForwardOp(
            self, forward_op_idx, input_to_grad):
        new_input_to_grad = {}
        gradient_ops = []
        forward_op, in_versions, out_versions = self.ssa[forward_op_idx]
        g_output = list(
            input_to_grad.get(name, None) for name in forward_op.output)

        if not all(g is None for g in g_output) or (
                forward_op.type == "ZeroGradient"):
            gradient_ops, g_input = GradientRegistry.GetGradientForOp(
                forward_op, g_output)
            # Check if the gradient operators are legal, and update
            # gradient_generators and gradient_frontier.
            self.BuildGradientGenerators(
                forward_op_idx, gradient_ops, g_output, g_input)
            # Record the gradient map to all_input_to_grad.
            for name, grad in zip(forward_op.input, g_input):
                # Do not overwrite an existing gradient with a None
                # unless the input is also an output of the op, since
                # the blob version is updated when the blob is an output
                # of an operator.
                if grad is not None or \
                    name not in input_to_grad or \
                        name in list(forward_op.output):
                    new_input_to_grad[name] = grad

        return new_input_to_grad, gradient_ops
    def GetBackwardPass(self, ys):
        """Gets the backward pass that computes the derivatives of given blobs.

        Inputs:
          ys: a list or a dictionary specifying what blobs we want to compute
              derivatives of. If the input is a list, we will automatically
              generate their gradients with all-one values; if the input is a
              dictionary, for any dictionary entries that are not None, we will
              take the corresponding blobs as their gradients; for all those
              that are None, we will auto-fill them with 1.
        """
        if isinstance(ys, list):
            ys = dict((y, None) for y in ys)
        elif not isinstance(ys, dict):
            raise TypeError("ys should either be a list or a dict.")

        # (1) Set the gradient frontier with the initialized external
        # gradients.
        for y in ys.keys():
            self.gradient_frontier[y] = self.frontier[y]
            self.input_usages[str(y)][self.frontier[str(y)]].append(
                len(self.ssa))

        all_input_to_grad, all_gradient_ops = self._GetInitGradients(ys)

        # (2) Go through the ops in reverse order, generating the gradient
        # operators along the way.
        for forward_op_idx in reversed(range(len(self.ssa))):
            input_to_grad, gradient_ops = self._GenerateGradientsForForwardOp(
                forward_op_idx, all_input_to_grad)
            all_input_to_grad.update(input_to_grad)
            all_gradient_ops += gradient_ops

            # If a gradient is accumulated from multiple operators rather
            # than computed by a single one, add the needed sum operators.
            additional_sum_ops, grad_map = self.DoGradientAccumulation(
                forward_op_idx)
            all_input_to_grad.update(grad_map)
            all_gradient_ops += additional_sum_ops

        # (3) Post-process to get a map from blob names to gradient blob
        # references.
        all_input_to_grad_out = {}
        for key, val in all_input_to_grad.items():
            if val is not None:
                if isinstance(val, (bytes, str)):
                    grad_out = BlobReference(val)
                else:
                    grad_out = GradientSlice(BlobReference(val[0]),
                                             BlobReference(val[1]))
                all_input_to_grad_out[BlobReference(key)] = grad_out
        return all_gradient_ops, all_input_to_grad_out
class GradientRegistry:
    """GradientRegistry holds the mapping from operators to their gradients."""
    gradient_registry_ = {}

    @classmethod
    def RegisterGradient(cls, op_type):
        """A decorator for registering gradient mappings."""

        def Wrapper(func):
            cls.gradient_registry_[op_type] = func
            return func

        return Wrapper
    @classmethod
    def _GetGradientForOpCC(cls, op_def, g_output):
        def from_untyped(grad):
            if grad is None:
                w = C.GradientWrapper()
                assert w.is_empty()
                return w
            try:
                (indices, values) = grad
                w = C.GradientWrapper()
                w.indices = indices
                w.values = values
                assert w.is_sparse()
                return w
            except ValueError:
                w = C.GradientWrapper()
                w.dense = grad
                assert w.is_dense()
                return w

        g_output = [from_untyped(grad) for grad in g_output]
        grad_defs_str, g_input = C.get_gradient_defs(
            op_def.SerializeToString(), g_output)

        def to_untyped(grad_wrapper):
            if grad_wrapper.is_empty():
                return None
            if grad_wrapper.is_sparse():
                return GradientSlice(grad_wrapper.indices, grad_wrapper.values)
            assert grad_wrapper.is_dense()
            return grad_wrapper.dense

        g_input = [to_untyped(grad_wrapper) for grad_wrapper in g_input]
        grad_defs = []
        for grad_def_str in grad_defs_str:
            grad_def = caffe2_pb2.OperatorDef()
            grad_def.ParseFromString(grad_def_str)
            grad_defs.append(grad_def)
        return grad_defs, g_input
    @classmethod
    def GetGradientForOp(cls, op, g_output):
        try:
            gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output)
        except Exception as e:
            # Not supported in C++; fall back to the Python registry.
            if op.type in cls.gradient_registry_:
                gradient_ops, g_input = cls.gradient_registry_[op.type](
                    op, g_output
                )
            else:
                raise Exception(
                    "Exception when creating gradient for [{}]:{}.\nOp: \n{}".
                    format(op.type, e, str(op))
                )

        if gradient_ops is None:
            return [], g_input
        if type(gradient_ops) is not list:
            gradient_ops = [gradient_ops]
        return gradient_ops, g_input
    @classmethod
    def GetBackwardPass(cls, operators, ys, ys_generate_gradient=False):
        """Gets the backward pass for the list of operators.

        Args:
            operators: a list of operators constituting the forward pass.
            ys: a list or a dictionary specifying what blobs we want to compute
                derivatives of. If the input is a list, we will automatically
                generate their gradients with all-one values; if the input is a
                dictionary, for any dictionary entries that are not None, we'll
                take the corresponding blobs as their gradients; for all those
                that are None, we will auto-fill them with 1.
        Returns:
            gradient_ops: a list of gradient operators to run.
            all_input_to_grads: a map from inputs to their corresponding
                gradients.
        """
        ir = IR(operators)
        return ir.GetBackwardPass(ys)
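
# Illustrative sketch (commented out; assumes a forward net whose loss blob
# is named "loss"):
#
#   grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
#       net.Proto().op, ["loss"])
#   # grad_ops can then be appended to a net; input_to_grad maps each input
#   # blob to its gradient blob (or a GradientSlice for sparse gradients).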
GradientRegistry.RegisterGradient('Do')(gen_do_gradient)
GradientRegistry.RegisterGradient('If')(gen_if_gradient)
GradientRegistry.RegisterGradient('While')(gen_while_gradient)
def get_ssa(net, blob_versions=None):
    """
    Given a net, return a structure containing the version of each input and
    output blob used by each operator.

    Args:
        net:            either a Net or a NetDef
        blob_versions:  (optional) map with current version number for given
                        blob names. If not provided or blob not found, start
                        from version 0.
    Returns:
        Tuple (ssa, blob_versions)
        ssa:            list of tuples (versioned_inputs, versioned_outputs)
                        for each op in the net. A versioned input is a tuple
                        (blob_name, version).
        blob_versions:  updated map with latest version of each blob found in
                        the net.
    """
    proto = net.Proto() if isinstance(net, Net) else net
    assert isinstance(proto, caffe2_pb2.NetDef)
    if blob_versions is None:
        blob_versions = {}
    if isinstance(net, list):
        return [get_ssa(n, blob_versions) for n in net], blob_versions
    for i in proto.external_input:
        if i not in blob_versions:
            blob_versions[str(i)] = 0
    ssa = []
    for op in proto.op:
        if not proto.external_input:
            for i in op.input:
                if i not in blob_versions:
                    blob_versions[i] = 0
        inputs = [(str(i), blob_versions.get(str(i), 0)) for i in op.input]
        for o in op.output:
            blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1
        outputs = [(str(o), blob_versions[str(o)]) for o in op.output]
        ssa.append((inputs, outputs))
    return ssa, blob_versions
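
# Illustrative sketch (commented out): for a net that writes blob "w" twice,
# get_ssa distinguishes the two versions.
#
#   net = Net("ssa_demo")
#   net.ConstantFill([], "w", shape=[1], value=0.0)
#   net.ConstantFill([], "w", shape=[1], value=1.0)
#   ssa, versions = get_ssa(net)
#   # versions["w"] == 2; the second op's output shows up as ("w", 2)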
def get_undefined_blobs(ssa):
    """
    Given a ssa in the format produced by get_ssa(), return a set of blobs that
    are used before they are defined, which corresponds to inputs at version 0.
    """
    undef_blobs = set()
    for inputs, _outputs in ssa:
        undef_blobs |= set(name for (name, ver) in inputs if ver == 0)
    return undef_blobs


def get_output_producers(ssa):
    """
    Given a ssa in the format produced by get_ssa(), returns a map from
    versioned blob into the operator index that produces that version of
    the blob. A versioned blob is a tuple (blob_name, version).
    """
    producers = {}
    for i, (_inputs, outputs) in enumerate(ssa):
        for o in outputs:
            producers[o] = i
    return producers
def get_op_ids_in_path(ssa, blob_versions, inputs, outputs):
    """
    Given a ssa and blob_versions as produced by get_ssa(), returns the list
    of op indices that are necessary in order to generate the blobs in
    `outputs`, given blobs in `inputs`.
    Consider that the `inputs` are given in their latest version.
    """
    inputs_set = set((str(i), blob_versions[str(i)]) for i in inputs)
    producers = get_output_producers(ssa)
    queue = [(str(o), blob_versions[str(o)]) for o in outputs]
    used_op_ids = set()
    while len(queue) > 0:
        o = queue.pop()
        if (o not in inputs_set) and (o in producers):
            op_id = producers[o]
            if op_id not in used_op_ids:
                used_op_ids |= {op_id}
                inputs, _ = ssa[op_id]
                queue.extend(inputs)
    return sorted(used_op_ids)
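
# Illustrative sketch (commented out): combining the SSA helpers above to
# find which ops are needed to produce "out" from "in".
#
#   ssa, versions = get_ssa(net)
#   op_ids = get_op_ids_in_path(ssa, versions, ["in"], ["out"])
#   # op_ids is a sorted list of indices into net.Proto().op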
def recurrent_network_op_remap(op, prefix, blob_remap):
    """
    Parameters
    ----------
    op : Caffe2 operator (RecurrentNetworkOp or RecurrentNetworkGradientOp).
    prefix: this argument is not used in this function, just for legacy support.
    blob_remap : Dictionary that represents the map from old blob name to new.

    Updates blob names in arguments of RecurrentNetworkOp and
    RecurrentNetworkGradientOp to conform to cloned input and output of both
    operators and also makes sure names of locally generated blobs in arguments
    have the same prefix as the input and output of the operators.
    """

    def get_remapped_str(blob_str):
        if isinstance(blob_str, bytes):
            blob_str = blob_str.decode('utf-8')
        return blob_remap.get(blob_str, blob_str).encode('utf-8')

    for argument in op.arg:
        if len(argument.strings) > 0:
            for i in range(len(argument.strings)):
                argument.strings[i] = get_remapped_str(argument.strings[i])
        elif argument.name == 'timestep':
            argument.s = get_remapped_str(argument.s)
        elif argument.name.endswith('step_net'):
            # argument.n should be a NetDef
            remap_proto(argument, blob_remap)
def control_op_remap(op, prefix, blob_remap):
    net_arg_names = []
    if op.type == "If" or op.type == "AsyncIf":
        net_arg_names = ['then_net', 'else_net']
    else:
        net_arg_names = ['loop_net', 'cond_net']
    for argument in op.arg:
        if argument.name in net_arg_names:
            assert argument.n, \
                "Expected non empty net in " + op.type + "'s " + argument.name + " argument"
            subnet = Net(argument.n)
            remapped_subnet = subnet.Clone(
                name=(subnet._net.name if subnet._net.name else '') + '_remapped',
                blob_remap=blob_remap)
            argument.n.CopyFrom(remapped_subnet.Proto())


DEFAULT_REMAP_FUNCS = {
    'RecurrentNetwork': recurrent_network_op_remap,
    'RecurrentNetworkGradient': recurrent_network_op_remap,
    'If': control_op_remap,
    'While': control_op_remap,
    'AsyncIf': control_op_remap,
}
def remap_proto(argument, blob_remap):
    subnet = Net(argument.n)
    cloned_sub_net = subnet.Clone('cloned_sub_net', blob_remap)
    argument.n.CopyFrom(cloned_sub_net.Proto())
def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None,
                       keep_schema=True):
    """
    Clone the given Net, binding its input schema to the given `inputs` record.
    Blob names defined by the net are prepended with the given `prefix`.

    Args:
        net:        the net to clone
        name:       the name of the new net
        prefix:     the prefix to prepend to local blobs
        blob_remap: (optional) dict with additional blob name remapping.
        inputs:     (optional) input record that will provide actual input
                    values for the cloned net. Must be compatible with the
                    net's input schema or be a strict superset of it
        keep_schema: by default (True), the original schema will be kept and
                     remapped accordingly. otherwise, the schema will be set as
                     inputs or left empty if inputs is not given.
    Returns:
        Tuple (cloned_net, blob_remap)
        clone_net:  the cloned Net
        blob_remap: a map from original blob names into remapped blob names
    """
    from caffe2.python import schema
    assert isinstance(net, Net)
    if blob_remap is None:
        blob_remap = {}
    if inputs is not None:
        assert isinstance(inputs, schema.Field)
        original = net.input_record()
        assert original is not None
        diff = set(original.field_names()) - set(inputs.field_names())
        assert len(diff) == 0, (
            "Schemas don't match, extra fields {diff} found in the net {name}. "
            "original: {original}; inputs: {inputs}"
            .format(
                diff=diff, name=net.Name(), original=original.field_names(),
                inputs=inputs.field_names()
            )
        )
        original_mapping = dict(zip(original.field_names(),
                                    original.field_blobs()))
        for fn, fb in zip(inputs.field_names(), inputs.field_blobs()):
            if fn in original_mapping:
                blob_remap[str(original_mapping[fn])] = str(fb)
    proto = net.Proto()
    ssa, blob_versions = get_ssa(proto)
    undef_blobs = get_undefined_blobs(ssa)

    for blob in blob_versions.keys():
        if blob in blob_remap:
            continue
        elif blob in undef_blobs:
            blob_remap[blob] = blob
        else:
            blob_remap[blob] = prefix + blob
    cloned_net = net.Clone(name, blob_remap, keep_schema=keep_schema)
    if not keep_schema and inputs:
        cloned_net.set_input_record(inputs)
    return cloned_net, blob_remap
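
# Illustrative sketch (commented out): cloning a net under a prefix while
# leaving its undefined (external) blobs unrenamed.
#
#   cloned, remap = clone_and_bind_net(net, "trainer", "trainer/")
#   # remap maps every blob defined inside `net` to "trainer/<blob>"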
def _get_blob_ref(blob_name_or_ref):
    return (
        blob_name_or_ref if isinstance(blob_name_or_ref, BlobReference)
        else BlobReference(blob_name_or_ref)
    )
def _recover_record_by_prefix(names, prefix=''):
    """
    Tries to recover record by taking a subset of blob names with
    a given prefix name and interpreting them as schema column names.
    """
    from caffe2.python import schema
    column_names = [name[len(prefix):] for name in names
                    if name.startswith(prefix)]
    if not column_names:
        return None
    return schema.from_column_list(
        column_names,
        col_blobs=[_get_blob_ref(prefix + name) for name in column_names])
class Net(object):
    _net_names_used_counters: Dict[str, int] = {}
    _net_names_used: Set[str] = set()
    operator_registry_ = {}

    @staticmethod
    def current_prefix():
        from caffe2.python.net_builder import NetBuilder
        builder = NetBuilder.current(required=False)
        return builder.name if builder else ''

    @staticmethod
    def _reset_used_names() -> None:
        Net._net_names_used_counters = {}
        Net._net_names_used = set()

    @staticmethod
    def _get_next_net_name(basename):
        basename = "/".join(x for x in [Net.current_prefix(), basename] if x)
        idx = Net._net_names_used_counters.get(basename, 0)
        while (
            name := basename if idx == 0 else f"{basename}_{idx}"
        ) in Net._net_names_used:
            idx += 1
        Net._net_names_used_counters[basename] = idx + 1
        Net._net_names_used.add(name)
        return name
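
    # Illustrative sketch (commented out; assumes a fresh process or a call
    # to Net._reset_used_names()): repeated basenames are disambiguated with
    # numeric suffixes.
    #
    #   assert Net._get_next_net_name("train") == "train"
    #   assert Net._get_next_net_name("train") == "train_1"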
    def __init__(self, name_or_proto, inplace=False):
        """
        Create a Net.
        Args:
            name_or_proto: If a NetDef is provided, clone it (or take ownership,
                           depending on the value of `inplace`). Otherwise,
                           create an empty net with the given name.
            inplace: If a NetDef is provided, take ownership when `inplace` is True;
                     otherwise, clone it.
        """
        self._input_record = None
        self._output_record = None
        # Register blobs so that it's guaranteed that different calls to
        # NextBlob/NextScopedBlob always return blobs with different names.
        self._registered_blob_names = set()
        self._recreate_lookup_tables = False
        self._op_outputs = set()
        self._external_input_map = set()
        self._attr_dict = defaultdict(list)
        if type(name_or_proto) is caffe2_pb2.NetDef:
            proto = name_or_proto
            if inplace:
                self._net = proto
            else:
                self._net = caffe2_pb2.NetDef()
                self._net.CopyFrom(proto)

            existing_outputs = [list(op.output) for op in self._net.op]

            self._external_input_map.update(list(self._net.external_input))

            # Set the next name index properly.
            existing_names = set()
            for op in self._net.op:
                existing_names.update(list(op.input))
            for output in existing_outputs:
                existing_names.update(output)

            for outs in existing_outputs:
                self._op_outputs.update(outs)

            prefix_len = len(self._net.name + '_blob_')
            autogen_indices = []
            for s in existing_names:
                if s.startswith(self._net.name + '_blob_'):
                    try:
                        autogen_indices.append(int(s[prefix_len:]))
                    except ValueError:
                        pass
            if len(autogen_indices):
                self._next_name_index = max(autogen_indices) + 1
            else:
                self._next_name_index = 0
            name = self._net.name
        else:
            name = name_or_proto
            self._net = caffe2_pb2.NetDef()
            self._next_name_index = 0

        # make sure that this net name hasn't been used before
        self._net.name = Net._get_next_net_name(name)

        # a map between prefix and ID for fast generation of blob names
        self._next_blob_name_ids = {}
    def AppendNet(self, net, device_option=None):
        assert isinstance(net, Net)
        for i in net.Proto().external_input:
            if (
                i not in self.Proto().external_input and
                i not in self._op_outputs
            ):
                self.Proto().external_input.append(i)

        self.Proto().external_output.extend(
            [
                o for o in net.Proto().external_output
                if o not in self.Proto().external_output
            ]
        )
        ops = net.Proto().op
        if device_option is not None:
            ops = [copy.deepcopy(op) for op in ops]
            for op in ops:
                op.device_option.CopyFrom(device_option)
                if op.type == "RecurrentNetwork":
                    for arg in op.arg:
                        if arg.name.endswith('step_net'):
                            for step_op in arg.n.op:
                                step_op.device_option.CopyFrom(device_option)

        self._ExtendOps(ops)
        return self
    def LogInfo(self, *msg_or_blobs):
        for msg_or_blob in msg_or_blobs:
            if not isinstance(msg_or_blob, BlobReference):
                blob = self.GivenTensorStringFill(
                    [], self.NextName('log'),
                    shape=[], values=[msg_or_blob])
            else:
                blob = msg_or_blob
            self.Print(blob, [])

    def add_attribute(self, name, obj):
        """
        Add `obj` to the list of attributes in this net under the given `name`.
        Attributes are user-defined objects and have no pre-defined semantics.
        """
        self._attr_dict[name].append(obj)

    def get_attributes(self, name):
        """
        Returns the list of attributes in this net for a given `name`.
        Attributes are user-defined objects added with `add_attribute`.
        """
        return self._attr_dict.get(name, [])
    def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
        """
        Adds a random seed to each op in the net.
        If sequence_seed is set, the i-th op has rand_seed=`seed + i`
        If seed_on_op_def is set, the op rand_seed=hash(str(op))
        sequence_seed and seed_on_op_def cannot be both set to True.
        """
        assert not (sequence_seed and seed_on_op_def), (
            'sequence_seed and seed_on_op_def cannot be both set to True.')
        for i, op in enumerate(self.Proto().op):
            if sequence_seed:
                curr_seed = seed + i
            elif seed_on_op_def:
                curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
            else:
                curr_seed = seed
            op.device_option.random_seed = curr_seed

    def Name(self):
        return self._net.name

    def __str__(self):
        return self.Name()
    def Const(self, array, blob_out=None, dtype=None):
        if isinstance(array, bool):
            return self.ConstantFill(
                [],
                blob_out or 1,
                dtype=DataType.BOOL,
                value=array)

        if dtype is None:
            array = np.array(array)
        else:
            array = np.array(array, dtype=dtype)

        def do_set(operator):
            return operator(
                [],
                blob_out or 1,
                shape=array.shape,
                values=array.flatten().tolist())

        if array.dtype == np.int32:
            return do_set(self.GivenTensorIntFill)
        elif array.dtype == np.int64:
            return do_set(self.GivenTensorInt64Fill)
        elif array.dtype == str:
            return do_set(self.GivenTensorStringFill)
        elif array.dtype == bool:
            return do_set(self.GivenTensorBoolFill)
        else:
            return do_set(self.GivenTensorFill)
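
    # Illustrative sketch (commented out): Const dispatches on the array's
    # dtype to pick the right fill operator.
    #
    #   net = Net("const_demo")
    #   w = net.Const(np.array([1.0, 2.0], dtype=np.float32), "w")
    #   # adds a GivenTensorFill op producing blob "w"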
    def BlobIsDefined(self, blob):
        """
        Returns true if the given BlobReference is produced as output of
        an operator in this net, or if it is provided as an external input.
        """
        if self._recreate_lookup_tables:
            self._RecreateLookupTables()
        name = str(blob)
        return (name in self._op_outputs) or (name in self._external_input_map)

    def UsesBlob(self, blob):
        """
        Returns true iff the given BlobReference is used by any operator
        of this net, or if it is one of the external inputs of the net.
        """
        blob_name = str(blob)
        for op in self._net.op:
            for input in op.input:
                if input == blob_name:
                    return True
        return blob_name in self._external_input_map
    def UsedBlobNames(self):
        """
        Returns a set of blob names used in the net.
        """
        blob_names = set()
        for op in self._net.op:
            blob_names |= set(op.input)
            blob_names |= set(op.output)
        if self._net.external_input:
            blob_names |= set(self._net.external_input)
        if self._net.external_output:
            blob_names |= set(self._net.external_output)
        return blob_names

    def GetBlobRef(self, blob_name):
        """
        Given the name of a blob produced by this net, return a BlobReference
        to it. If the blob is not produced by any op in this net, raises a
        KeyError.
        """
        blob_name = str(blob_name)
        if not self.BlobIsDefined(blob_name):
            raise KeyError('Net does not define blob %s' % blob_name)
        return BlobReference(blob_name, self)
1707
update_external_list=False,
1712
name: name of the cloned net
1713
blob_remap: optional map with list of blob names to replace
1714
op_id_mask: optional list of operator indices to include in
1715
the cloned net. If not provided, all ops are included.
1717
orig_remap_funcs = {} if remap_funcs is None else remap_funcs
1721
remap_funcs = DEFAULT_REMAP_FUNCS.copy()
1722
remap_funcs.update(orig_remap_funcs)
1724
new_proto = caffe2_pb2.NetDef()
1725
new_proto.CopyFrom(proto)
1726
new_proto.name = name
1728
if blob_remap is None:
1730
if op_id_mask is None:
1731
op_id_mask = list(range(0, len(proto.op)))
1733
def get_remapped_str(blob):
1734
blob_str = str(blob)
1735
return str(blob_remap.get(blob_str, blob_str))
1737
def remap_list(proto_list):
1738
new_list = [get_remapped_str(b) for b in proto_list]
1740
proto_list.extend(new_list)
1743
new_op = caffe2_pb2.OperatorDef()
1745
remap_list(new_op.input)
1746
remap_list(new_op.output)
1747
if new_op.type in remap_funcs:
1748
remap_funcs[new_op.type](
1750
(name + '/') if name else '',
1756
new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
1757
remap_list(new_proto.external_input)
1758
remap_list(new_proto.external_output)
1759
new_net = Net(new_proto)
        if keep_schema:
            from caffe2.python import schema
            if self._input_record:
                new_net._input_record = schema.from_blob_list(
                    self._input_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._input_record.field_blobs()
                    ],
                )
            if self._output_record:
                new_net._output_record = schema.from_blob_list(
                    self._output_record,
                    [
                        BlobReference(get_remapped_str(blob), net=new_net)
                        for blob in self._output_record.field_blobs()
                    ],
                )

        new_net._attr_dict.update(self._attr_dict)
        if update_external_list:
            # Rebuild the external input/output lists from actual usage.
            existing_outputs = set()
            used_outputs = set()
            del new_net.Proto().external_input[:]
            del new_net.Proto().external_output[:]
            for op in new_net.Proto().op:
                for ib in op.input:
                    if ib not in existing_outputs:
                        new_net.Proto().external_input.extend([ib])
                    else:
                        used_outputs.add(ib)
                for ob in op.output:
                    existing_outputs.add(ob)
            for ob in existing_outputs:
                if ob not in used_outputs:
                    new_net.Proto().external_output.extend([ob])
        return new_net
    def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
        """
        Clone this net, including only ops that are necessary in order to
        compute `outputs` given `inputs`. Return references to the cloned
        outputs. Internal blobs (blobs that are produced and consumed inside
        the net but not used as outputs) will be remapped to avoid name
        conflict.

        Args:
            name:    the name of the cloned net
            inputs:  map where the keys correspond to BlobReferences in the
                     original net, and the values correspond to external inputs
                     in the partially cloned net. If `inputs` is a list, don't
                     remap input names.
            outputs: outputs to be produced by the cloned net.

        Returns:
            Tuple (new_net, new_outputs)
            new_net: a new Net object.
            new_outputs: list of BlobReferences corresponding to the
                         outputs produced by new_net.
        """
        input_is_pair_list = isinstance(inputs, list) and all(
            isinstance(i, tuple) and len(i) == 2 for i in inputs)
        inputs = (
            inputs if isinstance(inputs, (dict, OrderedDict)) else
            OrderedDict(inputs) if input_is_pair_list else
            OrderedDict(zip(inputs, inputs)))
        for output in outputs:
            assert self.BlobIsDefined(output), "{} is not defined".format(output)
        input_names = {str(k): str(v) for k, v in inputs.items()}
        output_names = [str(o) for o in outputs]
        proto = self._net
        blob_versions = {str(i): 0 for i in inputs}
        ssa, blob_versions = get_ssa(proto, blob_versions)
        used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
        disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
        assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, (
            'Cannot partially clone net: some of the ops required would ' +
            'generate the given input.')

        sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids]
        undef_blobs = get_undefined_blobs(sub_ssa) - set(input_names.keys())
        prefix = (name + '/') if name else ''

        def remap(blob_name):
            if blob_name in input_names:
                return input_names[blob_name]
            elif blob_name in undef_blobs:
                return blob_name
            else:
                return prefix + blob_name

        blob_mapping = {b: remap(b) for b in blob_versions.keys()}
        new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
        new_in = [
            blob_mapping[i] for i in input_names.keys()] + list(undef_blobs)
        new_out = [blob_mapping[o] for o in output_names]
        del new_net.Proto().external_input[:]
        new_net.Proto().external_input.extend(new_in)
        new_net._external_input_map = set(list(new_in))
        del new_net.Proto().external_output[:]
        new_net.Proto().external_output.extend(new_out)
        return new_net, [new_net.GetBlobRef(o) for o in new_out]
    def Proto(self):
        self._InvalidateLookupTables()
        return self._net

    def insert_op_at_idx(self, op, op_idx):
        r"""Insert an operator at the given index. Will update external blob
        lists.
        """
        assert op_idx >= 0
        temp_ops = self.Proto().op[op_idx:]
        del self.Proto().op[op_idx:]
        self.Proto().op.extend([op])
        self.Proto().op.extend(temp_ops)
        self.external_outputs.extend(op.output)
        self.external_inputs.extend(op.input)
    def reroute_tensor(self, tensor, new_producer, can_modify=None):
        r"""Reroute tensor to new_producer, and feed the new tensor to
        consumers (intersected with can_modify, if provided).

        Inputs:
            tensor: str or BlobReference; the tensor to reroute
            new_producer: an op that takes in tensor and gives new_tensor
            can_modify: a list/set of operators that consume tensor and can be
            modified

        Returns:
            reroute_cnt: how many consumer ops have been changed

        Note: assumes no in-place blob in the net
        """
        def _find_tensor_input_op(tensor):
            if tensor in self.external_inputs:
                op_idx = -1
            else:
                assert tensor in new_producer.input, \
                    "new producer {} is not taking in {}".format(
                        new_producer.type, tensor)
                # find the producer of tensor among the existing ops
                op_idx = -2
                for index, op in enumerate(self.Proto().op):
                    if tensor in op.output:
                        op_idx = index
                        break
            return op_idx

        op_idx = max(_find_tensor_input_op(t) for t in new_producer.input)
        self.insert_op_at_idx(new_producer, op_idx + 1)
        new_tensor = new_producer.output[0]
        # modify external outputs
        if tensor in self.external_outputs:
            new_list = [new_tensor if b == tensor else b for b in self.external_outputs]
            del self.Proto().external_output[:]
            self.Proto().external_output.extend(new_list)

        # modify consumers
        reroute_cnt = 0
        if can_modify:
            for op in self.Proto().op:
                if op in can_modify:
                    remap_input(op, {tensor: new_tensor})
                    reroute_cnt = reroute_cnt + 1
        return reroute_cnt
    def PopulateProtoWithFileName(self):
        net_tb = workspace.operator_tracebacks.get(self.Name(), None)
        if net_tb is not None:
            for idx, op in enumerate(self.Proto().op):
                if idx in net_tb:
                    op.name = ':'.join(map(str, net_tb[idx][0]))

    def NextScopedBlob(self, prefix='unnamed'):
        """Return a blob that has not been defined or registered in the
        current net. It returns `ScopedBlobReference(prefix)`, if it's valid,
        otherwise `ScopedBlobReference(prefix) + '_auto_' + ?`. Different calls
        are guaranteed to return blobs with different names.
        """
        output_blob_base = ScopedName(prefix)
        return self.NextBlob(output_blob_base)
    def NextBlob(self, prefix='unnamed'):
        """Return a blob that has not been defined or registered in the
        current net. It returns `BlobReference(prefix)`, if it's valid,
        otherwise `BlobReference(prefix) + '_auto_' + ?`. Different calls
        are guaranteed to return blobs with different names."""
        output_blob_base = BlobReference(prefix)
        output_blob = output_blob_base
        index = 0
        while str(output_blob) in self._registered_blob_names or (
                self.BlobIsDefined(output_blob)):
            output_blob = output_blob_base + '_auto_' + str(index)
            index += 1

        self._registered_blob_names.add(str(output_blob))
        return output_blob
    def NextName(self, prefix=None, output_id=None):
        """Returns the next name to be used, if you do not want to explicitly
        name your blob. [Deprecated, use NextBlob, NextScopedBlob instead]"""
        if prefix:
            output_name_base = self._net.name + '/' + prefix
            output_name = output_name_base
            if output_id is not None:
                output_name += ':' + str(output_id)
            key = output_name
            index = self._next_blob_name_ids.get(key, 2)
            while self.BlobIsDefined(str(ScopedBlobReference(output_name))):
                output_name = output_name_base + '_' + str(index)
                if output_id is not None:
                    output_name += ':' + str(output_id)
                index += 1
                self._next_blob_name_ids[key] = index
        else:
            output_name = self._net.name + '_blob_' + str(self._next_name_index)
            self._next_name_index += 1
        return str(output_name)
    def _ExtendOps(self, new_ops):
        self._net.op.extend(new_ops)
        for op in new_ops:
            self._op_outputs.update([str(o) for o in op.output])

    def _CheckLookupTables(self):
        """
        Called from unit tests to validate the internal lookup tables
        match the protobuf contents.
        """
        test_op_outputs = set()
        for op in self._net.op:
            for o in op.output:
                test_op_outputs.add(o)

        test_external_inp = set()
        for inp in self._net.external_input:
            test_external_inp.add(inp)

        assert test_op_outputs.difference(self._op_outputs) == set()
        assert test_external_inp.difference(self._external_input_map) == set()

    def _InvalidateLookupTables(self):
        self._recreate_lookup_tables = True

    def _RecreateLookupTables(self):
        self._op_outputs = {o for op in self._net.op for o in op.output}
        self._external_input_map = {inp for inp in self._net.external_input}
        self._recreate_lookup_tables = False
    def AddGradientOperators(self, ys, skip=0):
        """Add the gradient for operators in the net.

        Inputs:
          ys: a list or a dictionary specifying what blobs we want to compute
              derivatives of. If the input is a list, we will automatically
              generate their gradients with all-one values; if the input is a
              dictionary, for any dictionary entries that are not None, we will
              take the corresponding blobs as their gradients; for all those
              that are None, we will auto-fill them with 1.
          skip: skips the first n operators. This is provided mainly because a
              lot of nets may use the first few operators for data generation,
              which really does not need to have gradients.

        Outputs:
          returns a map from the blob name in the input network to a blob
          containing gradient or a GradientSlice in case of sparse gradient

        Currently, this is hard-coded for float operators if there are branches
        (i.e. a blob is used as input to multiple operators). This is because
        the gradient accumulation (Sum) is float only right now.
        """
        grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
            self._net.op[skip:], ys)
        # Check if in immediate mode: the grad_ops are actually being produced
        # by C++ and bypass the CreateOperator() call, so in immediate mode
        # we have to explicitly run them.
        if workspace.IsImmediate():
            for op in grad_ops:
                workspace.RunOperatorImmediate(op)
        self._ExtendOps(grad_ops)
        return input_to_grad
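
    # Illustrative sketch (commented out): typical training-net usage,
    # assuming "loss" and a parameter blob "w" are produced by this net.
    #
    #   input_to_grad = net.AddGradientOperators(["loss"])
    #   # input_to_grad["w"] is the BlobReference of w's gradient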
    def AddArgument(self, arg_name, arg_value):
        self._net.arg.extend([utils.MakeArgument(arg_name, arg_value)])

    def AddExternalInput(self, *inputs):
        assert len(inputs) > 0
        refs = []
        input_name_set = set()
        for input in inputs:
            input_name = str(input)
            assert (
                input_name not in self._external_input_map
                and input_name not in input_name_set
            ), ("Net already contains an input named %s" % input_name)
            input_name_set.add(input_name)
        for input in inputs:
            input_name = str(input)
            self._net.external_input.extend([input_name])
            self._external_input_map.update([input_name])
            refs.append(_get_blob_ref(input_name))

        return refs[0] if len(refs) == 1 else refs
    def AddExternalOutput(self, *outputs):
        for output in outputs:
            assert isinstance(output, BlobReference)
            assert self.BlobIsDefined(output), "{} is not defined".format(output)
        for output in outputs:
            self.Proto().external_output.extend([str(output)])

    def AddScopedExternalInputs(self, *inputs):
        res = self.AddExternalInput(
            *[ScopedBlobReference(b) for b in inputs]
        )
        if not isinstance(res, list):
            res = [res]
        return res

    def AddScopedExternalOutputs(self, *outputs):
        return self.AddExternalOutput(
            *[ScopedBlobReference(b) for b in outputs]
        )
    def AddObserver(self, observer_type):
        return C.add_observer_to_net(self._net.name, observer_type)

    def RemoveObserver(self, observer):
        C.remove_observer_from_net(self._net.name, observer)

    def NumObservers(self):
        return C.num_observers_on_net(self._net.name)

    @property
    def external_inputs(self):
        return [_get_blob_ref(x) for x in self._net.external_input]

    @property
    def external_outputs(self):
        return [_get_blob_ref(x) for x in self._net.external_output]
    def set_input_record(self, input_record):
        from caffe2.python import schema
        assert self._input_record is None or (input_record.has_blobs() and
            set(input_record.field_blobs()) ==
            set(self._input_record.field_blobs())), (
            'Input schema cannot be reset')
        if not input_record.has_blobs():
            with NameScope(self.Name()):
                self._input_record = schema.NewRecord(self, input_record)
        else:
            self._input_record = input_record

        for blob in self._input_record.field_blobs():
            if not self.is_external_input(blob):
                self.AddExternalInput(blob)
        return self._input_record

    def recover_input_record_by_prefix(self, prefix):
        """
        Tries to recover input record by taking a subset of external_inputs
        with a given prefix name and interpreting them as schema column names.
        """
        record = _recover_record_by_prefix(self._net.external_input, prefix)
        if record:
            self.set_input_record(record)

    def set_output_record(self, record):
        assert self._output_record is None or (record.has_blobs() and
            set(record.field_blobs()) ==
            set(self._output_record.field_blobs())), (
            'Output schema cannot be reset')
        for blob in record.field_blobs():
            assert self.BlobIsDefined(blob), \
                "{} is not defined in net {}".format(blob, self.Name())
        for blob in record.field_blobs():
            if blob not in self.external_outputs:
                self.AddExternalOutput(blob)
        self._output_record = record

    def recover_output_record_by_prefix(self, prefix):
        """Tries to recover the output record by taking the subset of
        external_outputs with a given prefix name and interpreting them as
        schema column names.
        """
        record = _recover_record_by_prefix(self._net.external_output, prefix)
        if record:
            self.set_output_record(record)

    def AppendOutputRecordField(self, field_name, record):
        from caffe2.python import schema
        assert self._output_record is not None, (
            'Tried to append to missing output record'
        )
        for blob in record.field_blobs():
            assert self.BlobIsDefined(blob), "{} is not defined".format(blob)
        for blob in record.field_blobs():
            self.AddExternalOutput(blob)
        self._output_record = self._output_record + schema.Struct(
            (field_name, record)
        )

    def input_record(self):
        return self._input_record

    def output_record(self):
        return self._output_record

    def AddExternalInputs(self, *inputs):
        return self.AddExternalInput(*inputs)

    def AddExternalOutputs(self, *outputs):
        self.AddExternalOutput(*outputs)

    def DeduplicateGradientSlices(self, g, aggregator='sum'):
        assert isinstance(g, GradientSlice)
        unique, remapping = self.Unique([g.indices], 2, engine='SparseHash')
        if aggregator.lower() == 'sum':
            new_g = self.UnsortedSegmentSum([g.values, remapping], 1)
        elif aggregator.lower() == 'mean':
            new_g = self.UnsortedSegmentMean([g.values, remapping], 1)
        else:
            raise ValueError('{} is not supported'.format(aggregator))
        return GradientSlice(indices=unique, values=new_g)
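
    # Illustrative sketch (hypothetical inputs): collapsing duplicate indices
    # in a sparse gradient before applying an update.
    @staticmethod
    def _example_deduplicate_gradient_slices(net, sparse_grad):
        # `sparse_grad` is assumed to be a GradientSlice, e.g. the gradient
        # returned by AddGradientOperators for the data input of a Gather op.
        deduped = net.DeduplicateGradientSlices(sparse_grad, aggregator='sum')
        # `deduped.indices` holds the unique indices; `deduped.values` holds
        # the per-unique-index sums.
        return deduped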

    @staticmethod
    def _RunAllOnGPU(net, gpu_id=0, use_cudnn=False):
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = workspace.GpuDeviceType
        device_option.device_id = gpu_id
        net.device_option.CopyFrom(device_option)
        if use_cudnn:
            for op in net.op:
                op.engine = "CUDNN"
        # Move RecurrentNetwork step nets to the GPU as well.
        for op in net.op:
            if op.type != "RecurrentNetwork":
                continue
            for arg in op.arg:
                if arg.name == "step_net":
                    Net._RunAllOnGPU(arg.n, gpu_id, use_cudnn)

    def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
        """A convenient function to run everything on the GPU."""
        self._RunAllOnGPU(self._net, gpu_id, use_cudnn)

    def RunAllOnMKL(self):
        """A convenient function to run everything using MKLDNN."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.MKLDNN
        self._net.device_option.CopyFrom(device_option)

    def RunAllOnIDEEP(self):
        """A convenient function to run everything using IDEEP."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.IDEEP
        self._net.device_option.CopyFrom(device_option)

    def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs):
        """A helper function to create an operator and add it to self."""
        inputs = _RectifyInputOutput(inputs)
        for input in inputs:
            if not self.BlobIsDefined(input):
                assert input.Net() != self
                self.AddExternalInput(input)
        if outputs is None:
            # If we do not specify an output, we will assume that this op
            # produces one output.
            outputs = self.NextName(prefix=op_type)
        elif type(outputs) is int:
            # In this case, we will auto-fill the given number of outputs
            # with auto-generated names.
            outputs = [
                self.NextName(prefix=op_type, output_id=i)
                for i in range(outputs)]
        outputs = _RectifyInputOutput(outputs, net=self)
        op = CreateOperator(op_type, inputs, outputs, **kwargs)
        self._ExtendOps([op])

        workspace.operator_tracebacks[self.Name()][
            len(self._net.op) - 1] = _extract_stacktrace()

        if len(op.output) == 0:
            return
        elif len(op.output) == 1:
            return BlobReference(op.output[0], self)
        else:
            return tuple(BlobReference(o, self) for o in op.output)

    def __getattr__(self, op_type):
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if not IsOperator(op_type) and not IsOperatorWithEngine(op_type, "CUDNN"):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            op_type, *args, **kwargs)
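
    # Illustrative sketch: thanks to __getattr__ above, any registered
    # operator can be invoked as a method on the net. The call below appends
    # a Relu OperatorDef and returns a BlobReference to its output; names are
    # hypothetical.
    @staticmethod
    def _example_operator_sugar():
        net = Net("sugar_example")
        x = net.AddExternalInput("x")
        y = net.Relu([x], "y")
        return y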

    def __dir__(self):
        TriggerLazyImport()
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op]
        return sorted(set(chain(
            dir(type(self)),
            self.__dict__.keys(),
            additional_methods
        )))

    def Python(
        self,
        f,
        grad_f=None,
        python_func_type=None,
        pass_workspace=False,
        grad_output_indices=None,
        grad_input_indices=None
    ):
        """
        Registers and returns a python operator.

        `f` and `grad_f` can be one of the following:
            - a function with signature (inputs, outputs), where inputs and
              outputs are lists of CPUTensor objects. This function will be
              called from C++ every time the operator is executed.
            - a tuple (func, args, kwargs), where `func` is a callable, args
              is an argument list, and kwargs is a dict. The call:
                  f = func(*args, **kwargs)
              will be performed locally at node initialization time, on all of
              the nodes of the job, returning `f`, a callable that will be
              used as the python operator function during Net execution. This
              is to be used with python operators in a distributed context,
              and allows one to create and keep local python state across
              calls to the operator.

        `python_func_type` is the type of an object that is constructed as
        python_func_type(f) and provides an implementation of the forward and
        backward functions. It is useful when a user needs a stateful
        PythonOp (e.g. using autograd to compute grad_f).

        If `pass_workspace` is True, the signature is changed to
        (inputs, outputs, workspace) where `workspace` is the workspace the op
        is going to run on. This is potentially dangerous (as the op can
        manipulate the workspace directly), so use at your own risk.

        If a gradient function is specified (`grad_f`), by default its inputs
        will be: (1) all inputs to `f`, (2) followed by all outputs of `f`,
        (3) and then all gradient outputs of `f`. The outputs of `grad_f` will
        be (by default) all gradient inputs to `f`. If a subset of the
        gradient outputs or gradient inputs is desired instead, the subsets
        can be specified by providing `grad_output_indices` and/or
        `grad_input_indices`, which identify the indices of `f`'s inputs and
        outputs which have gradients.
        """
        assert IsOperator('Python')

        def make_builder(t):
            if not isinstance(t, tuple):
                return ''
            assert len(t) == 3, 'Expected builder tuple (func, args, kwargs)'
            func, args, kwargs = t
            normalized = (func, tuple(args), dict(kwargs))
            return pickle.dumps(normalized)

        f_builder = make_builder(f)
        grad_f_builder = make_builder(grad_f)

        assert (not grad_f) or ((not f_builder) == (not grad_f_builder)), (
            'A tuple has to be passed to both f and grad_f or neither.')

        core_kwargs = {}
        if f_builder:
            core_kwargs['pickled_builder'] = f_builder
            core_kwargs['pickled_grad_builder'] = grad_f_builder
            core_kwargs['pass_workspace'] = pass_workspace
        else:
            core_kwargs['token'] = _RegisterPythonImpl(
                f, grad_f, python_func_type, pass_workspace=pass_workspace)

        grad_output_indices = grad_output_indices or []
        grad_input_indices = grad_input_indices or []
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            'Python',
            grad_output_indices=grad_output_indices,
            grad_input_indices=grad_input_indices,
            *args,
            **dict(chain(kwargs.items(), core_kwargs.items()))
        )
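
    # Illustrative sketch of the Python operator (the function `f` below is
    # hypothetical): `f` receives the input and output tensors and writes its
    # result into the output in place, following the (inputs, outputs)
    # signature described in the docstring above.
    @staticmethod
    def _example_python_op():
        def f(inputs, outputs):
            outputs[0].reshape(inputs[0].shape)
            outputs[0].data[...] = inputs[0].data * 2

        net = Net("python_op_example")
        x = net.AddExternalInput("x")
        # net.Python(f) returns an op builder that is then called like any
        # other operator method.
        y = net.Python(f)([x], ["doubled_x"])
        return y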

    def is_external_input(self, blob):
        if self._recreate_lookup_tables:
            self._RecreateLookupTables()

        name = str(blob)
        return name in self._external_input_map

    def extend_ops(self, new_ops):
        return self._ExtendOps(new_ops)


def remap_input(op, blob_name_remapping):
    new_list = [blob_name_remapping.get(b, b) for b in op.input]
    del op.input[:]
    op.input.extend(new_list)


def copy_func_between_devices(src, dst):
    CPU = caffe2_pb2.CPU
    is_src_gpu = IsGPUDeviceType(src.device_type)
    is_dst_gpu = IsGPUDeviceType(dst.device_type)

    if src.device_type == CPU and dst.device_type == CPU:
        return None

    if is_src_gpu and is_dst_gpu:
        if src.device_id == dst.device_id:
            return None
        else:
            def fun(net, *args, **kw):
                with DeviceScope(dst):
                    return net.Copy(*args, **kw)
            return fun

    if is_src_gpu and dst.device_type == CPU:
        def fun(net, *args, **kw):
            with DeviceScope(src):
                return net.CopyGPUToCPU(*args, **kw)
        return fun

    if src.device_type == CPU and is_dst_gpu:
        def fun(net, *args, **kw):
            with DeviceScope(dst):
                return net.CopyCPUToGPU(*args, **kw)
        return fun

    raise ValueError('Non-supported devices: %s and %s' % (src, dst))
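

# Illustrative sketch (hypothetical device options and blob): obtaining a
# copy function for a CPU -> GPU move and applying it to a net.
def _example_copy_func_between_devices(net, blob):
    src = DeviceOption(caffe2_pb2.CPU)
    dst = DeviceOption(workspace.GpuDeviceType, 0)
    copy_func = copy_func_between_devices(src, dst)
    # None means the two devices match and no copy is needed; otherwise the
    # returned function appends the right Copy op (here CopyCPUToGPU).
    if copy_func is not None:
        return copy_func(net, blob, str(blob) + '_gpu_0')
    return blob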


def device_equal(src, dst):
    """
    We use this function instead of the == operator because optional-value
    comparison between an empty device_option and {device_type:0, device_id:0}
    returns not-equal in some cases.
    """
    return src.device_type == dst.device_type and src.device_id == dst.device_id


def update_placeholder_op_output(op, blob_to_device):
    """
    Placeholder ops (e.g. Recv) always run on CPU, so ensure that their
    output blobs reside on CPU.
    """
    outputs = []
    for output in op.output:
        if (output in blob_to_device and
                blob_to_device[output].device_type != caffe2_pb2.CPU):
            output += '_cpu'
        outputs.append(output)
    del op.output[:]
    op.output.extend(outputs)


class RemapEntry:
    def __init__(self, blob, device):
        self.blob = blob
        self.device = device

    def __eq__(self, other):
        return self.blob == other.blob and self.device == other.device

    def __hash__(self):
        return hash(self.blob + str(self.device))


def InjectCrossDeviceCopies(net, blob_to_device=None, blob_remap=None,
                            placeHolderOps=None):
    """
    Injects Copy functions between devices within a net. Users can provide
    a net with some of its operators using different device_options. This
    method will automatically create a new net with Copy ops inserted in it.

    Inputs:
      blob_to_device: If not None, it is a map of blobs and their device
          locations.
      blob_remap: If not None, it is a map from a pair (blob, device) to
          the name of the blob in the given device. Blobs found in this
          map are assumed to be cached and don't need to be copied.
    Outputs:
      new_net: A new net with CopyCPUToGPU inserted with the correct device
          option.
      required_external_to_device:
          A mapping between unresolved external inputs and their
          required device options.
    Assumptions:
      1. every external input of this net is already in blob_to_device!
      2. if not, this function will use the net's device option
      3. InferOpBlobDevices might fail to get the correct inference for ops
         like EnsureCPUOutput that could take input from multiple places.
    """
    new_net = net.Clone(net._net.name + '_cross_device', keep_schema=True)
    del new_net._net.op[:]
    if blob_to_device is None:
        blob_to_device = {}
    # Maps a remap entry (blob, device) to the name of the copied blob.
    if blob_remap is None:
        blob_remap = {}
    net_option = net._net.device_option or caffe2_pb2.DeviceOption()

    # If external inputs have device remappings generated by previous nets,
    # add the remapped names as external inputs as well.
    all_remaps = defaultdict(list)
    for entry, mapped_blob in blob_remap.items():
        all_remaps[entry.blob].append(mapped_blob)
    mapped_external_inputs = []
    for input in new_net._net.external_input:
        mapped_external_inputs.extend(all_remaps.get(input) or [])
    new_net._net.external_input.extend(mapped_external_inputs)

    for op in net._net.op:
        temp_remap = {}
        # Determine where the inputs and outputs should live. If this is a
        # placeholder (i.e. fake) op, infer based on its inputs instead.
        if placeHolderOps is not None and op.type in placeHolderOps:
            input_dev, output_dev = InferOpDeviceAsBlobDevices(op)
        else:
            input_dev, output_dev = InferOpBlobDevices(op)

        for dev, input in zip(input_dev, op.input):
            assert net.BlobIsDefined(input), \
                "input {} should be defined in the net.".format(input)
            if input not in blob_to_device:
                if net.is_external_input(input):
                    blob_to_device[input] = net_option
                else:
                    raise AttributeError(
                        "No device information found for blob {}.".
                        format(input)
                    )

            if not device_equal(blob_to_device[input], dev):
                # Reuse an already-moved copy of the input if one exists.
                if (RemapEntry(input, dev) in blob_remap and
                        blob_to_device[blob_remap[RemapEntry(input, dev)]] == dev):
                    temp_remap[input] = blob_remap[RemapEntry(input, dev)]
                else:
                    # Need to copy the input onto the correct device.
                    copy_func = copy_func_between_devices(
                        blob_to_device[input], dev
                    )

                    def _gen_new_name(blob, device_option):
                        CPU = caffe2_pb2.CPU
                        if device_option.device_type == CPU:
                            suffix = '_cpu'
                        elif IsGPUDeviceType(device_option.device_type):
                            suffix = '_gpu_' + str(device_option.device_id)
                        else:
                            raise RuntimeError(
                                "Unknown device type: {}".
                                format(device_option.device_type)
                            )
                        return blob + suffix

                    new_name = _gen_new_name(input, dev)
                    copy_func(new_net, input, new_name)
                    blob_remap[RemapEntry(input, dev)] = new_name
                    temp_remap[input] = new_name
                    blob_to_device[new_name] = dev

        if placeHolderOps is not None and op.type in placeHolderOps:
            update_placeholder_op_output(op, blob_to_device)

        # Enforce no blob reuse across operators with different devices.
        # In-place use within an op is allowed, on the assumption that an
        # in-place op keeps a single device option.
        for dev, output in zip(output_dev, op.output):
            if output in blob_to_device and (
                output not in op.input and
                not device_equal(blob_to_device[output], dev)
            ):
                raise RuntimeError(
                    "In-place blob: {} is not supported between operators "
                    "with different device option previous:{} now: {}. "
                    "Failed op:\n {}".format(
                        output, blob_to_device[output], dev, op
                    )
                )

        new_op = caffe2_pb2.OperatorDef()
        new_op.CopyFrom(op)
        new_list = [temp_remap.get(b, b) for b in new_op.input]
        del new_op.input[:]
        new_op.input.extend(new_list)

        # Keep in-place blobs in place.
        original_inputs = list(op.input)
        for i, out in enumerate(new_op.output):
            try:
                input_idx = original_inputs.index(out)
                new_op.output[i] = new_op.input[input_idx]
            except ValueError:
                pass

        blob_to_device.update(
            {o: d for d, o in zip(output_dev, new_op.output)})
        new_net.extend_ops([new_op])

    return new_net, blob_to_device
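

# Illustrative sketch (hypothetical net): given a net whose operators carry
# mixed device options, a single call injects the needed Copy ops.
def _example_inject_cross_device_copies(mixed_device_net):
    new_net, blob_to_device = InjectCrossDeviceCopies(mixed_device_net)
    # `new_net` is a clone named '<name>_cross_device' with Copy ops and
    # renamed blobs (e.g. 'x_gpu_0') inserted; `blob_to_device` records the
    # device of every resulting blob.
    return new_net, blob_to_device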


def InjectDeviceCopiesAmongNets(nets, blob_to_device_init=None):
    """
    Takes in a list of nets. They usually represent your whole execution
    graph. This function will insert cross-device copy functions in all nets
    and resolve inter-net external input dependencies: a Copy function is
    inserted whenever an external input of a net is produced on a different
    device than the one on which it is required.

    Inputs:
      nets: a list of nets
    Outputs:
      new_nets: a list of new nets with device differences resolved.

    Some notes from wyiming:
      1. You MUST pass nets in execution order. e.g. [train_init, train]
    """
    assert isinstance(nets, list), \
        "nets {} should be a list of nets.".format(str(nets))
    assert all(isinstance(net, Net) for net in nets), \
        "nets {} should be a list of nets.".format(str(nets))
    # A holistic blob-to-device mapping shared across all nets.
    blob_to_device = blob_to_device_init or {}
    blob_remap = {}
    new_nets = []

    for net in nets:
        new_net, blob_to_device = InjectCrossDeviceCopies(
            net,
            blob_to_device=blob_to_device,
            blob_remap=blob_remap,
        )
        new_nets.append(new_net)

    return new_nets, blob_to_device
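

# Illustrative sketch (hypothetical nets, passed in execution order as the
# docstring above requires):
def _example_inject_device_copies_among_nets(train_init, train):
    new_nets, blob_to_device = InjectDeviceCopiesAmongNets([train_init, train])
    new_init, new_train = new_nets
    return new_init, new_train, blob_to_device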


def InjectDeviceCopiesAmongNetsWithoutB2D(nets, blob_to_device_init=None):
    new_nets, _ = InjectDeviceCopiesAmongNets(nets, blob_to_device_init)
    return new_nets


def get_net_name(netlike):
    if isinstance(netlike, Net):
        return netlike.Proto().name
    elif isinstance(netlike, caffe2_pb2.NetDef):
        return netlike.name
    else:
        return netlike


def output_to_list(op_output):
    """
    Ensures that the output of an operator is a list.
    Use when an operator has a variable number of outputs, but a list of
    outputs is desired even when the number of outputs is 1.

    Args:
        op_output: Either a BlobReference or an iterable of BlobReferences.

    Returns:
        A list of BlobReferences.
    """
    assert type(op_output) in (list, tuple, BlobReference)
    return (
        [op_output]
        if isinstance(op_output, BlobReference) else list(op_output))
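

# Illustrative sketch (hypothetical blobs): normalizing single- and
# multi-output operator results to a list.
def _example_output_to_list(net, x):
    single = net.Relu([x])             # a lone BlobReference
    pair = net.Split([x], ["a", "b"])  # a tuple of BlobReferences
    return output_to_list(single), output_to_list(pair)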


def _add_net_to_dict(net_dict, net):
    name = get_net_name(net)
    if name in net_dict:
        assert net_dict[name] is None or net == net_dict[name], (
            'Different nets with same name: ' + name)
        return False
    else:
        net_dict[name] = net if isinstance(net, Net) else None
        return True


class ExecutionStep:
    _step_names_used = set()

    @staticmethod
    def _get_next_step_name(basename):
        name = basename
        next_idx = 1
        while name in ExecutionStep._step_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        ExecutionStep._step_names_used |= set([name])
        return name

    def __init__(self, name, nets=None, num_iter=None):
        self._step = caffe2_pb2.ExecutionStep()
        self._step.name = name or ExecutionStep._get_next_step_name('step')
        self._net_dict = OrderedDict()
        self._is_used = False
        self._substeps = []
        if nets is not None:
            if type(nets) is Net:
                nets = [nets]
            for net in nets:
                if _add_net_to_dict(self._net_dict, net):
                    self._step.network.extend([get_net_name(net)])
            if num_iter is not None:
                self._step.num_iter = num_iter

    def get_net(self, name):
        return self._net_dict[name]

    def Name(self):
        return self._step.name

    def __str__(self):
        return self._step.name

    def _assert_can_mutate(self):
        assert not self._is_used, (
            'Cannot mutate a step that has already been added to a plan/step.')

    def _notify_is_used(self):
        self._is_used = True

    def Proto(self):
        self._notify_is_used()
        return self._step

    def HasNets(self):
        return self._step.network is not None and (
            len(self._step.network) > 0)

    def HasSubsteps(self):
        return self._step.substep is not None and (
            len(self._step.substep) > 0)

    def Nets(self):
        return list(self._net_dict.values())

    def Substeps(self):
        return self._substeps

    def SetIter(self, num_iter):
        self._assert_can_mutate()
        self._step.num_iter = num_iter

    def SetCreateWorkspace(self, create_workspace):
        self._assert_can_mutate()
        self._step.create_workspace = create_workspace

    def SetNumConcurrentInstances(self, num_concurrent_instances):
        self._assert_can_mutate()
        self._step.num_concurrent_instances = num_concurrent_instances

    def SetOnlyOnce(self, only_once):
        self._assert_can_mutate()
        self._step.only_once = only_once

    def SetShouldStopBlob(self, should_stop_blob):
        assert isinstance(should_stop_blob, BlobReference), (
            "expects BlobReference here, got {}".format(type(should_stop_blob)))
        self._assert_can_mutate()
        self._step.should_stop_blob = str(should_stop_blob)

    def RunEveryMillis(self, interval):
        """
        Run this step every `interval` milliseconds, as long as its
        siblings are still running. It is guaranteed that, after all
        siblings finish, this step will run at least once.

        This property is ignored for top-level ExecutionSteps.
        """
        self._step.run_every_ms = interval

    def SetReportNet(self, report_net, report_interval):
        """DEPRECATED. Use RunEveryMillis instead."""
        self._assert_can_mutate()
        _add_net_to_dict(self._net_dict, report_net)
        self._step.report_net = get_net_name(report_net)
        self._step.report_interval = report_interval

    def AddSubstep(self, substep):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        if isinstance(substep, ExecutionStep):
            substep._notify_is_used()
            if not substep.HasNets() and not substep.HasSubsteps():
                return self
            for net in substep.Nets():
                _add_net_to_dict(self._net_dict, net)
            self._substeps.append(substep)
            proto = substep.Proto()
        else:
            proto = substep
        self._step.substep.add().CopyFrom(proto)
        return self

    def SetConcurrentSubsteps(self, concurrent_substeps):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        self._step.concurrent_substeps = concurrent_substeps

    def AddNet(self, net):
        self._assert_can_mutate()
        assert not self.HasSubsteps(), 'Cannot have both network and substeps.'
        assert isinstance(net, Net)
        _add_net_to_dict(self._net_dict, net)
        self._step.network.extend([get_net_name(net)])
        return self

    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this execution step and its children.
        """
        return [
            attr
            for net in self._net_dict.values()
            for attr in net.get_attributes(name)
        ]

    @classmethod
    def create_from_proto(cls, step_proto, net_obj_dict, net_proto_dict):
        """
        Create an ExecutionStep from an ExecutionStep protobuf, recursively.
        """
        assert isinstance(step_proto, caffe2_pb2.ExecutionStep)
        assert (len(step_proto.network) > 0 and len(step_proto.substep) == 0) or \
            (len(step_proto.network) == 0 and len(step_proto.substep) > 0)

        steps_or_nets = []
        if len(step_proto.substep) > 0:
            for substep_proto in step_proto.substep:
                steps_or_nets.append(ExecutionStep.create_from_proto(
                    substep_proto, net_obj_dict, net_proto_dict))
        else:
            for net_name in step_proto.network:
                if net_name not in net_obj_dict:
                    assert net_name in net_proto_dict
                    net = Net(net_proto_dict[net_name])
                    net_obj_dict[net_name] = net
                net = net_obj_dict[net_name]
                assert isinstance(net, Net)
                steps_or_nets.append(net)

        num_iter = step_proto.num_iter if step_proto.HasField('num_iter') else None
        concurrent_substeps = step_proto.concurrent_substeps if \
            step_proto.HasField('concurrent_substeps') else None
        should_stop_blob = BlobReference(step_proto.should_stop_blob) if \
            step_proto.HasField('should_stop_blob') else None
        only_once = step_proto.only_once if \
            step_proto.HasField('only_once') else None
        num_concurrent_instances = step_proto.num_concurrent_instances if \
            step_proto.HasField('num_concurrent_instances') else None
        create_workspace = step_proto.create_workspace if \
            step_proto.HasField('create_workspace') else None
        run_every_ms = step_proto.run_every_ms if \
            step_proto.HasField('run_every_ms') else None

        return execution_step(
            step_proto.name,
            steps_or_nets,
            num_iter=num_iter,
            report_net=None,
            report_interval=None,
            concurrent_substeps=concurrent_substeps,
            should_stop_blob=should_stop_blob,
            only_once=only_once,
            num_concurrent_instances=num_concurrent_instances,
            create_workspace=create_workspace,
            run_every_ms=run_every_ms)


def add_nets_in_order(step, net_list):
    proto = step.Proto()
    for substep in step.Substeps():
        add_nets_in_order(substep, net_list)
    for net in proto.network:
        if net not in net_list:
            net_list.append(net)
    if proto.report_net and proto.report_net not in net_list:
        net_list.append(proto.report_net)


class Plan:
    def __init__(self, name_or_step):
        self._plan = caffe2_pb2.PlanDef()
        self._net_dict = OrderedDict()
        self._steps = []    # A list of ExecutionStep
        if isinstance(name_or_step, ExecutionStep):
            self._plan.name = name_or_step.Name()
            self.AddStep(name_or_step)
        elif isinstance(name_or_step, basestring):
            self._plan.name = name_or_step
        else:
            raise ValueError('name_or_step must be a string or ExecutionStep')

    def __str__(self):
        return self._plan.name

    def Proto(self):
        return self._plan

    def AddNets(self, nets):
        for net in nets:
            if _add_net_to_dict(self._net_dict, net):
                assert isinstance(net, Net)
                self._plan.network.add().CopyFrom(net.Proto())

    def Nets(self):
        return list(self._net_dict.values())

    def AddStep(self, step):
        assert isinstance(step, ExecutionStep)
        step._notify_is_used()
        if not step.HasNets() and not step.HasSubsteps():
            return
        self._plan.execution_step.add().CopyFrom(step.Proto())
        self._steps.append(step)
        # Nets need to be added to the plan in order of usage.
        net_list = []
        add_nets_in_order(step, net_list)
        self.AddNets([step.get_net(n) for n in net_list])

    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this plan.
        """
        return [
            attr
            for net in self._net_dict.values()
            for attr in net.get_attributes(name)
        ]

    @classmethod
    def create_from_proto(cls, plan_proto):
        assert isinstance(plan_proto, caffe2_pb2.PlanDef)
        plan = Plan(plan_proto.name)
        plan._plan.CopyFrom(plan_proto)
        del plan._plan.network[:]
        del plan._plan.execution_step[:]

        net_obj_dict = {}
        net_proto_dict = {}
        for net_proto in plan_proto.network:
            assert net_proto.name not in net_proto_dict
            net_proto_dict[net_proto.name] = net_proto

        for step_proto in plan_proto.execution_step:
            step = ExecutionStep.create_from_proto(
                step_proto, net_obj_dict, net_proto_dict)
            plan.AddStep(step)

        return plan


def to_execution_step(step_or_nets, default_name=None):
    from caffe2.python.net_builder import NetBuilder
    if isinstance(step_or_nets, ExecutionStep):
        return step_or_nets

    stop_blob = None
    if not default_name and hasattr(step_or_nets, 'name'):
        default_name = step_or_nets.name
    if isinstance(step_or_nets, NetBuilder):
        stop_blob = step_or_nets._stop_blob
        step_or_nets = step_or_nets.get()
    return execution_step(
        default_name, step_or_nets, should_stop_blob=stop_blob)


def execution_step(default_name,
                   steps_or_nets,
                   num_iter=None,
                   report_net=None,
                   report_interval=None,
                   concurrent_substeps=None,
                   should_stop_blob=None,
                   only_once=None,
                   num_concurrent_instances=None,
                   create_workspace=False,
                   run_every_ms=None):
    """
    Helper for creating an ExecutionStep.
    - steps_or_nets can be:
      - None
      - Net
      - ExecutionStep
      - list<Net>
      - list<ExecutionStep>
    - should_stop_blob is either None or a scalar boolean blob.
      - This blob is checked AFTER every substep/subnet.
      - If specified and true, then this step will return immediately.
      - Be sure to handle race conditions if setting from concurrent threads.
    - if no should_stop_blob or num_iter is provided, defaults to num_iter=1
    """
    assert should_stop_blob is None or num_iter is None, (
        'Cannot set both should_stop_blob and num_iter.')
    if should_stop_blob is None and num_iter is None:
        num_iter = 1

    step = ExecutionStep(default_name)
    if should_stop_blob is not None:
        step.SetShouldStopBlob(should_stop_blob)
    if num_iter is not None:
        step.SetIter(num_iter)
    if only_once is not None:
        step.SetOnlyOnce(only_once)
    if concurrent_substeps is not None:
        step.SetConcurrentSubsteps(concurrent_substeps)
    if report_net is not None:
        assert report_interval is not None
        step.SetReportNet(report_net, report_interval)
    if num_concurrent_instances is not None:
        step.SetNumConcurrentInstances(num_concurrent_instances)
    if create_workspace:
        step.SetCreateWorkspace(True)
    if run_every_ms:
        step.RunEveryMillis(run_every_ms)

    if isinstance(steps_or_nets, ExecutionStep):
        step.AddSubstep(steps_or_nets)
    elif isinstance(steps_or_nets, Net):
        step.AddNet(steps_or_nets)
    elif isinstance(steps_or_nets, list):
        if all(isinstance(x, Net) for x in steps_or_nets):
            for x in steps_or_nets:
                step.AddNet(x)
        else:
            for x in steps_or_nets:
                step.AddSubstep(to_execution_step(x))
    elif steps_or_nets:
        raise ValueError(
            'steps_or_nets must be a step, a net, or a list of nets or steps.')
    return step
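

# Illustrative sketch (hypothetical nets): composing an init step that runs
# once and a training step that runs 100 iterations into a single plan.
def _example_execution_step(init_net, train_net):
    init_step = execution_step('init', init_net, only_once=True)
    train_step = execution_step('train', train_net, num_iter=100)
    plan = Plan('example_plan')
    plan.AddStep(execution_step('main', [init_step, train_step]))
    return plan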


def scoped_execution_step(name, *args, **kwargs):
    """Same as execution_step() except that the step name is scoped."""
    default_name = ScopedName(name) if name else name
    return execution_step(default_name, *args, **kwargs)


def _extract_stacktrace():
    """
    This function extracts the stacktrace without file-system access,
    purely via sys._getframe(), and removes the part that belongs to
    this file (core.py). We are not using the inspect module because
    it is just a wrapper on top of sys._getframe(), whose logic is based
    on accessing source files on disk - exactly what we are trying to
    avoid here. The same holds for the traceback module.

    The reason for avoiding file-system access is that if the code is
    located on an NFS, file access might be slow.

    Returns a list of tuples (file_name, line_number, function).
    """
    result = []
    # Skip the frames that belong to core.py itself: this function and the
    # _CreateAndAddToSelf call chain above it.
    frame = sys._getframe(3)
    # Just walk down the frame stack in a loop.
    while frame:
        # It is important to extract the information from the frame here,
        # as the frame's current line will most probably change later.
        result.append((frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name))
        frame = frame.f_back
    return result


SetPerOpEnginePref = C.set_per_op_engine_pref
SetGlobalEnginePref = C.set_global_engine_pref
SetEnginePref = C.set_engine_pref
SetOpEnginePref = C.set_op_engine_pref