pytorch

Форк
0
139 строк · 4.7 Кб
1
## @package model_helper_api
2
# Module caffe2.python.model_helper_api
3

4

5

6

7

8
import sys
9
import copy
10
import inspect
11
from past.builtins import basestring
12
from caffe2.python.model_helper import ModelHelper
13

14
# flake8: noqa
15
from caffe2.python.helpers.algebra import *
16
from caffe2.python.helpers.arg_scope import *
17
from caffe2.python.helpers.array_helpers import *
18
from caffe2.python.helpers.control_ops import *
19
from caffe2.python.helpers.conv import *
20
from caffe2.python.helpers.db_input import *
21
from caffe2.python.helpers.dropout import *
22
from caffe2.python.helpers.elementwise_linear import *
23
from caffe2.python.helpers.fc import *
24
from caffe2.python.helpers.nonlinearity import *
25
from caffe2.python.helpers.normalization import *
26
from caffe2.python.helpers.pooling import *
27
from caffe2.python.helpers.quantization import *
28
from caffe2.python.helpers.tools import *
29
from caffe2.python.helpers.train import *
30

31

32
class HelperWrapper(object):
33
    _registry = {
34
        'arg_scope': arg_scope,
35
        'fc': fc,
36
        'packed_fc': packed_fc,
37
        'fc_decomp': fc_decomp,
38
        'fc_sparse': fc_sparse,
39
        'fc_prune': fc_prune,
40
        'dropout': dropout,
41
        'max_pool': max_pool,
42
        'average_pool': average_pool,
43
        'max_pool_with_index' : max_pool_with_index,
44
        'lrn': lrn,
45
        'softmax': softmax,
46
        'instance_norm': instance_norm,
47
        'spatial_bn': spatial_bn,
48
        'spatial_gn': spatial_gn,
49
        'moments_with_running_stats': moments_with_running_stats,
50
        'relu': relu,
51
        'prelu': prelu,
52
        'tanh': tanh,
53
        'concat': concat,
54
        'depth_concat': depth_concat,
55
        'sum': sum,
56
        'reduce_sum': reduce_sum,
57
        'sub': sub,
58
        'arg_min': arg_min,
59
        'transpose': transpose,
60
        'iter': iter,
61
        'accuracy': accuracy,
62
        'conv': conv,
63
        'conv_nd': conv_nd,
64
        'conv_transpose': conv_transpose,
65
        'group_conv': group_conv,
66
        'group_conv_deprecated': group_conv_deprecated,
67
        'image_input': image_input,
68
        'video_input': video_input,
69
        'add_weight_decay': add_weight_decay,
70
        'elementwise_linear': elementwise_linear,
71
        'layer_norm': layer_norm,
72
        'mat_mul' : mat_mul,
73
        'batch_mat_mul' : batch_mat_mul,
74
        'cond' : cond,
75
        'loop' : loop,
76
        'db_input' : db_input,
77
        'fused_8bit_rowwise_quantized_to_float' : fused_8bit_rowwise_quantized_to_float,
78
        'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
79
    }
80

81
    def __init__(self, wrapped):
82
        self.wrapped = wrapped
83

84
    def __getattr__(self, helper_name):
85
        if helper_name not in self._registry:
86
            raise AttributeError(
87
                "Helper function {} not "
88
                "registered.".format(helper_name)
89
            )
90

91
        def scope_wrapper(*args, **kwargs):
92
            new_kwargs = {}
93
            if helper_name != 'arg_scope':
94
                if len(args) > 0 and isinstance(args[0], ModelHelper):
95
                    model = args[0]
96
                elif 'model' in kwargs:
97
                    model = kwargs['model']
98
                else:
99
                    raise RuntimeError(
100
                "The first input of helper function should be model. " \
101
                "Or you can provide it in kwargs as model=<your_model>.")
102
                new_kwargs = copy.deepcopy(model.arg_scope)
103
            func = self._registry[helper_name]
104
            var_names, _, varkw, _= inspect.getargspec(func)
105
            if varkw is None:
106
                # this helper function does not take in random **kwargs
107
                new_kwargs = {
108
                    var_name: new_kwargs[var_name]
109
                    for var_name in var_names if var_name in new_kwargs
110
                }
111

112
            cur_scope = get_current_scope()
113
            new_kwargs.update(cur_scope.get(helper_name, {}))
114
            new_kwargs.update(kwargs)
115
            return func(*args, **new_kwargs)
116

117
        scope_wrapper.__name__ = helper_name
118
        return scope_wrapper
119

120
    def Register(self, helper):
121
        name = helper.__name__
122
        if name in self._registry:
123
            raise AttributeError(
124
                "Helper {} already exists. Please change your "
125
                "helper name.".format(name)
126
            )
127
        self._registry[name] = helper
128

129
    def has_helper(self, helper_or_helper_name):
130
        helper_name = (
131
            helper_or_helper_name
132
            if isinstance(helper_or_helper_name, basestring) else
133
            helper_or_helper_name.__name__
134
        )
135
        return helper_name in self._registry
136

137

138
# pyre-fixme[6]: incompatible parameter type: expected ModuleType, got HelperWrapper
139
sys.modules[__name__] = HelperWrapper(sys.modules[__name__])
140

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.