pytorch

Форк
0
362 строки · 10.2 Кб
1
## @package conv
2
# Module caffe2.python.helpers.conv
3

4

5

6

7

8
from caffe2.python import core
9
from caffe2.python.modeling import initializers
10
from caffe2.python.modeling.parameter_info import ParameterTags
11

12
def _ConvBase(
    model,
    is_nd,
    blob_in,
    blob_out,
    dim_in,
    dim_out,
    kernel,
    weight_init=None,
    bias_init=None,
    WeightInitializer=None,
    BiasInitializer=None,
    group=1,
    transform_inputs=None,
    use_cudnn=False,
    order="NCHW",
    cudnn_exhaustive_search=False,
    ws_nbytes_limit=None,
    float16_compute=False,
    **kwargs
):
    """Shared implementation behind ``conv`` and ``conv_nd``.

    Creates the weight parameter (and the bias parameter unless ``no_bias``
    is truthy in ``kwargs``) via ``model.create_param`` and adds a ``Conv``
    operator to ``model.net``.

    Args:
        model: model helper providing ``net``, ``init_params`` and
            ``create_param``.
        is_nd: if True, ``kernel`` is a single size or a list/tuple of
            per-dimension sizes, passed to the operator as ``kernels``.
            Otherwise a 2D convolution is built: a scalar ``kernel`` is used
            for both spatial dimensions, a 2-element list/tuple is passed as
            ``kernel_h``/``kernel_w``.
        blob_in: input blob.
        blob_out: output blob name; ``model.net.NextName()`` is used when
            falsy. Parameters are named ``blob_out + '_w'`` / ``'_b'``.
        dim_in: number of input channels (divided by ``group`` for the
            weight shape).
        dim_out: number of output channels.
        kernel: kernel size specification (see ``is_nd``).
        weight_init, bias_init, WeightInitializer, BiasInitializer: combined
            through ``initializers.update_initializer`` with defaults
            ``("XavierFill", {})`` and ``("ConstantFill", {})``. Replaced by
            ``initializers.ExternalInitializer()`` when ``model.init_params``
            is false.
        group: number of groups; forwarded to the operator only when != 1.
        transform_inputs: optional callable invoked as
            ``transform_inputs(model, blob_out, inputs)`` before the operator
            is created.
        use_cudnn: force the ``CUDNN`` engine.
        order: "NCHW" lays the weight out as
            ``[dim_out, dim_in // group, *kernels]``; any other value gives
            ``[dim_out, *kernels, dim_in // group]``.
        cudnn_exhaustive_search: forwarded as ``exhaustive_search`` when
            ``use_cudnn`` is true.
        ws_nbytes_limit: forwarded when ``use_cudnn`` is true and the value
            is truthy.
        float16_compute: when true, forwards ``float16_compute=True``.
        **kwargs: extra operator arguments (``no_bias`` is consumed here).

    Returns:
        The result of ``model.net.Conv``.

    Raises:
        ValueError: if an explicit ``engine`` conflicts with ``use_cudnn``.
        AssertionError: if a 2D convolution gets a kernel sequence whose
            length is not 2.
    """
    # Normalize the kernel spec into a list of spatial sizes for the weight
    # shape. Tuples are accepted in the same way as lists.
    kernel_is_seq = isinstance(kernel, (list, tuple))
    if is_nd:
        kernels = list(kernel) if kernel_is_seq else [kernel]
    elif kernel_is_seq:
        assert len(kernel) == 2, "Conv support only a 2D kernel."
        kernels = list(kernel)
    else:
        kernels = [kernel] * 2

    requested_engine = kwargs.get('engine')
    if requested_engine is not None:
        if use_cudnn and requested_engine != 'CUDNN':
            raise ValueError(
                'When use_cudnn=True, the only engine you can specify is '
                '"CUDNN"')
        elif not use_cudnn and requested_engine == 'CUDNN':
            raise ValueError(
                'When use_cudnn=False, the only engine you can specify is '
                '""')

    if use_cudnn:
        kwargs['engine'] = 'CUDNN'
        kwargs['exhaustive_search'] = cudnn_exhaustive_search
        if ws_nbytes_limit:
            kwargs['ws_nbytes_limit'] = ws_nbytes_limit

    # The operator figures out "no bias" from its number of inputs, so the
    # flag is consumed here instead of being forwarded.
    use_bias = not kwargs.pop("no_bias", False)
    blob_out = blob_out or model.net.NextName()

    # Floor division keeps the per-group channel count exact;
    # int(dim_in / group) would round-trip through a float.
    dim_in_per_group = int(dim_in // group)
    if order == "NCHW":
        weight_shape = [dim_out, dim_in_per_group] + kernels
    else:
        weight_shape = [dim_out] + kernels + [dim_in_per_group]

    WeightInitializer = initializers.update_initializer(
        WeightInitializer, weight_init, ("XavierFill", {})
    )
    BiasInitializer = initializers.update_initializer(
        BiasInitializer, bias_init, ("ConstantFill", {})
    )
    if not model.init_params:
        # The model does not initialize parameters itself: declare them with
        # external initializers instead.
        WeightInitializer = initializers.ExternalInitializer()
        BiasInitializer = initializers.ExternalInitializer()

    weight = model.create_param(
        param_name=blob_out + '_w',
        shape=weight_shape,
        initializer=WeightInitializer,
        tags=ParameterTags.WEIGHT
    )
    inputs = [blob_in, weight]
    if use_bias:
        bias = model.create_param(
            param_name=blob_out + '_b',
            shape=[dim_out],
            initializer=BiasInitializer,
            tags=ParameterTags.BIAS
        )
        inputs.append(bias)

    if transform_inputs is not None:
        transform_inputs(model, blob_out, inputs)

    # Enable float 16 compute kernel (relevant for CUDA)
    if float16_compute:
        kwargs['float16_compute'] = True
    if group != 1:
        kwargs['group'] = group

    if is_nd:
        return model.net.Conv(
            inputs,
            blob_out,
            kernels=kernels,
            order=order,
            **kwargs)
    if kernel_is_seq:
        return model.net.Conv(
            inputs,
            blob_out,
            kernel_h=kernels[0],
            kernel_w=kernels[1],
            order=order,
            **kwargs)
    return model.net.Conv(
        inputs,
        blob_out,
        kernel=kernel,
        order=order,
        **kwargs)
140

141

142

143
def conv_nd(
    model,
    blob_in,
    blob_out,
    dim_in,
    dim_out,
    kernel,
    weight_init=None,
    bias_init=None,
    WeightInitializer=None,
    BiasInitializer=None,
    group=1,
    transform_inputs=None,
    order="NCHW",
    **kwargs
):
    """N-dimensional convolution for inputs with NCHW storage order.

    Forwards everything to ``_ConvBase`` with ``is_nd=True``, so ``kernel``
    may be a single size or a list of per-dimension sizes. Extra keyword
    arguments (e.g. ``use_cudnn``, ``no_bias``) are passed through.

    Raises:
        AssertionError: if ``order`` is anything other than "NCHW".
    """
    assert order == "NCHW", "ConvNd only supported for NCHW storage."
    return _ConvBase(
        model, True, blob_in, blob_out, dim_in, dim_out, kernel,
        weight_init=weight_init,
        bias_init=bias_init,
        WeightInitializer=WeightInitializer,
        BiasInitializer=BiasInitializer,
        group=group,
        transform_inputs=transform_inputs,
        order=order,
        **kwargs
    )
165

166

167
def conv(
    model,
    blob_in,
    blob_out,
    dim_in,
    dim_out,
    kernel,
    weight_init=None,
    bias_init=None,
    WeightInitializer=None,
    BiasInitializer=None,
    group=1,
    transform_inputs=None,
    **kwargs
):
    """2-dimensional convolution.

    Forwards to ``_ConvBase`` with ``is_nd=False``: ``kernel`` is either a
    single size used for both spatial dimensions or a 2-element list
    ``[kernel_h, kernel_w]``. Remaining keyword arguments (e.g.
    ``use_cudnn``, ``order``, ``no_bias``) are passed through unchanged.
    """
    return _ConvBase(
        model, False, blob_in, blob_out, dim_in, dim_out, kernel,
        weight_init=weight_init,
        bias_init=bias_init,
        WeightInitializer=WeightInitializer,
        BiasInitializer=BiasInitializer,
        group=group,
        transform_inputs=transform_inputs,
        **kwargs
    )
187

188

189
def conv_transpose(
    model,
    blob_in,
    blob_out,
    dim_in,
    dim_out,
    kernel,
    weight_init=None,
    bias_init=None,
    use_cudnn=False,
    order="NCHW",
    cudnn_exhaustive_search=False,
    ws_nbytes_limit=None,
    **kwargs
):
    """ConvTranspose.

    Adds a ``ConvTranspose`` operator together with its weight and bias
    parameters, both registered through ``model.AddParameter``.

    Args:
        model: model helper exposing ``net``, ``param_init_net``,
            ``init_params`` and ``AddParameter``.
        blob_in: input blob.
        blob_out: output blob name; ``model.net.NextName()`` when falsy.
            Parameters are named ``blob_out + '_w'`` and ``blob_out + '_b'``.
        dim_in: number of input channels.
        dim_out: number of output channels (also the bias length).
        kernel: kernel size, used for both spatial dimensions.
        weight_init, bias_init: ``(fill_op_name, fill_kwargs)`` pairs; they
            default to ``('XavierFill', {})`` and ``('ConstantFill', {})``.
        use_cudnn: select the ``CUDNN`` engine, together with
            ``cudnn_exhaustive_search`` and (when truthy) ``ws_nbytes_limit``.
        order: "NCHW" gives a weight of shape
            ``[dim_in, dim_out, kernel, kernel]``; any other value gives
            ``[dim_in, kernel, kernel, dim_out]``.
        **kwargs: extra operator arguments.

    Returns:
        The result of ``model.net.ConvTranspose``.
    """
    if not weight_init:
        weight_init = ('XavierFill', {})
    if not bias_init:
        bias_init = ('ConstantFill', {})
    if not blob_out:
        blob_out = model.net.NextName()

    if order == "NCHW":
        weight_shape = [dim_in, dim_out, kernel, kernel]
    else:
        weight_shape = [dim_in, kernel, kernel, dim_out]
    weight_name = blob_out + '_w'
    bias_name = blob_out + '_b'

    if not model.init_params:
        # No fill operators are emitted; the parameters are only referenced
        # by name.
        weight = core.ScopedBlobReference(weight_name, model.param_init_net)
        bias = core.ScopedBlobReference(bias_name, model.param_init_net)
    else:
        make_weight = model.param_init_net.__getattr__(weight_init[0])
        weight = make_weight(
            [], weight_name, shape=weight_shape, **weight_init[1])
        make_bias = model.param_init_net.__getattr__(bias_init[0])
        bias = make_bias([], bias_name, shape=[dim_out], **bias_init[1])

    model.AddParameter(weight, ParameterTags.WEIGHT)
    model.AddParameter(bias, ParameterTags.BIAS)

    if use_cudnn:
        kwargs.update(
            engine='CUDNN', exhaustive_search=cudnn_exhaustive_search)
        if ws_nbytes_limit:
            kwargs['ws_nbytes_limit'] = ws_nbytes_limit

    op_inputs = [blob_in, weight, bias]
    return model.net.ConvTranspose(
        op_inputs, blob_out, kernel=kernel, order=order, **kwargs)
245

246

247
def group_conv(
    model,
    blob_in,
    blob_out,
    dim_in,
    dim_out,
    kernel,
    weight_init=None,
    bias_init=None,
    group=1,
    **kwargs
):
    """Group Convolution.

    Equivalent to ``conv`` with an explicit ``group`` argument; kept as a
    separate entry point for backward interface compatibility. All other
    keyword arguments are forwarded to ``conv`` unchanged.
    """
    return conv(
        model, blob_in, blob_out, dim_in, dim_out, kernel,
        group=group,
        weight_init=weight_init,
        bias_init=bias_init,
        **kwargs
    )
267

268

269
def group_conv_deprecated(
    model,
    blob_in,
    blob_out,
    dim_in,
    dim_out,
    kernel,
    weight_init=None,
    bias_init=None,
    group=1,
    use_cudnn=False,
    order="NCHW",
    cudnn_exhaustive_search=False,
    ws_nbytes_limit=None,
    **kwargs
):
    """GroupConvolution's deprecated interface.

    This is used to simulate a group convolution via split and concat. You
    should always use the new group convolution in your new code.

    The input is split with ``DepthSplit`` into ``group`` slices of
    ``dim_in // group`` channels. Each slice gets its own ``Conv`` with
    parameters ``<blob_out>_gconv_<i>_w`` (and ``_gconv_<i>_b`` unless
    ``no_bias`` is truthy in ``kwargs``). The per-group outputs are joined
    with ``Concat`` into ``blob_out``.

    Args:
        model: model helper exposing ``net``, ``param_init_net``,
            ``init_params`` and ``AddParameter``.
        blob_in: input blob.
        blob_out: output blob name; ``model.net.NextName()`` when falsy.
        dim_in, dim_out: input / output channel counts; both must be
            divisible by ``group``.
        kernel: kernel size, used for both spatial dimensions.
        weight_init, bias_init: ``(fill_op_name, fill_kwargs)`` pairs; they
            default to ``('XavierFill', {})`` and ``('ConstantFill', {})``.
        group: number of groups.
        use_cudnn, cudnn_exhaustive_search, ws_nbytes_limit: CUDNN engine
            options forwarded to every per-group ``Conv``.
        order: storage order passed to the split, conv and concat ops.
        **kwargs: extra arguments for every per-group ``Conv``.

    Returns:
        The concatenated output blob.

    Raises:
        ValueError: if ``dim_in`` or ``dim_out`` is not divisible by
            ``group``.
    """
    weight_init = weight_init if weight_init else ('XavierFill', {})
    bias_init = bias_init if bias_init else ('ConstantFill', {})
    # The Conv operator infers "no bias" from its input count, so the flag
    # is consumed here (once) rather than forwarded.
    use_bias = not kwargs.pop("no_bias", False)
    if use_cudnn:
        kwargs['engine'] = 'CUDNN'
        kwargs['exhaustive_search'] = cudnn_exhaustive_search
        if ws_nbytes_limit:
            kwargs['ws_nbytes_limit'] = ws_nbytes_limit
    # Both divisibility checks must run regardless of the CUDNN options.
    if dim_in % group:
        raise ValueError("dim_in should be divisible by group.")
    if dim_out % group:
        raise ValueError("dim_out should be divisible by group.")

    blob_out = blob_out or model.net.NextName()
    # Floor division keeps channel counts exact (no float round-trip).
    in_per_group = int(dim_in // group)
    out_per_group = int(dim_out // group)

    splitted_blobs = model.net.DepthSplit(
        blob_in,
        ['_' + blob_out + '_gconv_split_' + str(i) for i in range(group)],
        dimensions=[in_per_group] * group,
        order=order
    )
    if order == "NCHW":
        weight_shape = [out_per_group, in_per_group, kernel, kernel]
    else:
        weight_shape = [out_per_group, kernel, kernel, in_per_group]
    # Make sure every entry (kernel included) is an int.
    weight_shape = [int(v) for v in weight_shape]

    conv_blobs = []
    for i in range(group):
        weight_name = blob_out + '_gconv_%d_w' % i
        bias_name = blob_out + '_gconv_%d_b' % i
        if model.init_params:
            weight = model.param_init_net.__getattr__(weight_init[0])(
                [],
                weight_name,
                shape=weight_shape,
                **weight_init[1]
            )
            if use_bias:
                bias = model.param_init_net.__getattr__(bias_init[0])(
                    [],
                    bias_name,
                    shape=[out_per_group],
                    **bias_init[1]
                )
        else:
            weight = core.ScopedBlobReference(
                weight_name, model.param_init_net)
            if use_bias:
                bias = core.ScopedBlobReference(
                    bias_name, model.param_init_net)

        model.AddParameter(weight, ParameterTags.WEIGHT)
        inputs = [weight]
        if use_bias:
            model.AddParameter(bias, ParameterTags.BIAS)
            inputs.append(bias)

        conv_blobs.append(
            splitted_blobs[i].Conv(
                inputs,
                blob_out + '_gconv_%d' % i,
                kernel=kernel,
                order=order,
                **kwargs
            )
        )

    concat, _concat_dims = model.net.Concat(
        conv_blobs,
        [blob_out,
         "_" + blob_out + "_concat_dims"],
        order=order
    )
    return concat
363

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.