11
from past.builtins import basestring
12
from caffe2.python.model_helper import ModelHelper
15
from caffe2.python.helpers.algebra import *
16
from caffe2.python.helpers.arg_scope import *
17
from caffe2.python.helpers.array_helpers import *
18
from caffe2.python.helpers.control_ops import *
19
from caffe2.python.helpers.conv import *
20
from caffe2.python.helpers.db_input import *
21
from caffe2.python.helpers.dropout import *
22
from caffe2.python.helpers.elementwise_linear import *
23
from caffe2.python.helpers.fc import *
24
from caffe2.python.helpers.nonlinearity import *
25
from caffe2.python.helpers.normalization import *
26
from caffe2.python.helpers.pooling import *
27
from caffe2.python.helpers.quantization import *
28
from caffe2.python.helpers.tools import *
29
from caffe2.python.helpers.train import *
32
class HelperWrapper(object):
34
'arg_scope': arg_scope,
36
'packed_fc': packed_fc,
37
'fc_decomp': fc_decomp,
38
'fc_sparse': fc_sparse,
42
'average_pool': average_pool,
43
'max_pool_with_index' : max_pool_with_index,
46
'instance_norm': instance_norm,
47
'spatial_bn': spatial_bn,
48
'spatial_gn': spatial_gn,
49
'moments_with_running_stats': moments_with_running_stats,
54
'depth_concat': depth_concat,
56
'reduce_sum': reduce_sum,
59
'transpose': transpose,
64
'conv_transpose': conv_transpose,
65
'group_conv': group_conv,
66
'group_conv_deprecated': group_conv_deprecated,
67
'image_input': image_input,
68
'video_input': video_input,
69
'add_weight_decay': add_weight_decay,
70
'elementwise_linear': elementwise_linear,
71
'layer_norm': layer_norm,
73
'batch_mat_mul' : batch_mat_mul,
76
'db_input' : db_input,
77
'fused_8bit_rowwise_quantized_to_float' : fused_8bit_rowwise_quantized_to_float,
78
'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
81
def __init__(self, wrapped):
82
self.wrapped = wrapped
84
def __getattr__(self, helper_name):
85
if helper_name not in self._registry:
87
"Helper function {} not "
88
"registered.".format(helper_name)
91
def scope_wrapper(*args, **kwargs):
93
if helper_name != 'arg_scope':
94
if len(args) > 0 and isinstance(args[0], ModelHelper):
96
elif 'model' in kwargs:
97
model = kwargs['model']
100
"The first input of helper function should be model. " \
101
"Or you can provide it in kwargs as model=<your_model>.")
102
new_kwargs = copy.deepcopy(model.arg_scope)
103
func = self._registry[helper_name]
104
var_names, _, varkw, _= inspect.getargspec(func)
108
var_name: new_kwargs[var_name]
109
for var_name in var_names if var_name in new_kwargs
112
cur_scope = get_current_scope()
113
new_kwargs.update(cur_scope.get(helper_name, {}))
114
new_kwargs.update(kwargs)
115
return func(*args, **new_kwargs)
117
scope_wrapper.__name__ = helper_name
120
def Register(self, helper):
121
name = helper.__name__
122
if name in self._registry:
123
raise AttributeError(
124
"Helper {} already exists. Please change your "
125
"helper name.".format(name)
127
self._registry[name] = helper
129
def has_helper(self, helper_or_helper_name):
131
helper_or_helper_name
132
if isinstance(helper_or_helper_name, basestring) else
133
helper_or_helper_name.__name__
135
return helper_name in self._registry
139
sys.modules[__name__] = HelperWrapper(sys.modules[__name__])