google-research

Форк
0
166 строк · 6.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""ResNet V1 implementation for UQ experiments.
17

18
Mostly derived from Keras documentation:
19
  https://keras.io/examples/cifar10_resnet/
20
"""
21

22
from __future__ import absolute_import
23
from __future__ import division
24
from __future__ import print_function
25

26
import functools
27
from absl import logging
28

29
from six.moves import range
30
import tensorflow.compat.v2 as tf
31
import tensorflow_probability as tfp
32
from uq_benchmark_2019 import uq_utils
33
# Short aliases for the Keras and TFP-distributions namespaces used by this
# module (`tfd` is not referenced elsewhere in this file).
keras, tfd = tf.keras, tfp.distributions
35

36

37
def _resnet_layer(inputs,
                  num_filters=16,
                  kernel_size=3,
                  strides=1,
                  activation='relu',
                  depth=20,
                  batch_norm=True,
                  conv_first=True,
                  variational=False,
                  std_prior_scale=1.5,
                  eb_prior_fn=None,
                  always_on_dropout_rate=None,
                  examples_per_epoch=None):
  """Builds one convolution / batch-norm / activation unit of the ResNet.

  With `conv_first=True` the unit is conv -> BN -> activation; otherwise it is
  BN -> activation -> conv. Batch normalization is only inserted when
  `batch_norm` is set AND the unit is not variational.

  Args:
    inputs (tensor): Input image tensor or output of the previous layer.
    num_filters (int): Number of Conv2D filters.
    kernel_size (int): Side length of the square Conv2D kernel.
    strides (int): Side length of the square Conv2D stride.
    activation (string): Name of the activation; None skips the activation.
    depth (int): Overall ResNet depth; scales the Fixup initialization of the
      variational kernel mean.
    batch_norm (bool): Whether to add batch normalization.
    conv_first (bool): True for conv-bn-activation, False for
      bn-activation-conv.
    variational (bool): Use a TFP Flipout convolution instead of Conv2D.
    std_prior_scale (float): Scale for the log-normal hyperprior.
    eb_prior_fn (callable): Empirical Bayes kernel prior for the TFP layer.
    always_on_dropout_rate (float): Dropout rate applied before the
      convolution, active at both train and test time; falsy disables it.
    examples_per_epoch (int): Number of examples per epoch, used to scale the
      variational KL term.

  Returns:
    The unit's output tensor, to be fed to the next layer.
  """
  # Keyword arguments shared by the variational and deterministic convs.
  conv_kwargs = dict(kernel_size=kernel_size, strides=strides, padding='same')

  if variational:
    kl_fn = uq_utils.make_divergence_fn_for_empirical_bayes(
        std_prior_scale, examples_per_epoch)

    def _fixup_loc_initializer(shape, dtype=None):
      """Fixup initialization; see https://arxiv.org/abs/1901.09321."""
      he_sample = keras.initializers.he_normal()(shape, dtype=dtype)
      return he_sample * depth**(-1/4)

    posterior_fn = tfp.layers.default_mean_field_normal_fn(
        loc_initializer=_fixup_loc_initializer)
    conv = tfp.layers.Convolution2DFlipout(
        num_filters,
        kernel_prior_fn=eb_prior_fn,
        kernel_posterior_fn=posterior_fn,
        kernel_divergence_fn=kl_fn,
        **conv_kwargs)
  else:
    conv = keras.layers.Conv2D(
        num_filters,
        kernel_initializer='he_normal',
        kernel_regularizer=keras.regularizers.l2(1e-4),
        **conv_kwargs)

  def _conv_with_dropout(tensor):
    """Applies optional always-on dropout, then the convolution."""
    logging.info('Applying conv layer; always_on_dropout=%s.',
                 always_on_dropout_rate)
    if always_on_dropout_rate:
      # training=True keeps dropout active at train and test time.
      dropout = keras.layers.Dropout(always_on_dropout_rate)
      tensor = dropout(tensor, training=True)
    return conv(tensor)

  out = inputs
  if conv_first:
    out = _conv_with_dropout(out)
  if batch_norm and not variational:
    out = keras.layers.BatchNormalization()(out)
  if activation is not None:
    out = keras.layers.Activation(activation)(out)
  if not conv_first:
    out = _conv_with_dropout(out)
  return out
109

110

111
def build_resnet_v1(input_layer, depth,
                    variational,
                    std_prior_scale,
                    always_on_dropout_rate,
                    examples_per_epoch,
                    eb_prior_fn=None,
                    no_first_layer_dropout=False,
                    num_filters=16):
  """ResNet Version 1 feature extractor, following [a].

  [a] K. He, X. Zhang, S. Ren, J. Sun, "Deep Residual Learning for Image
      Recognition", CVPR 2016 (arXiv:1512.03385).

  Builds a stem convolution followed by three stacks of (depth - 2) / 6
  residual blocks each. The first block of every stack after the first
  downsamples with stride 2, and the filter count doubles after each stack.
  The result is average-pooled and flattened; no classifier head is added.

  Args:
    input_layer: Keras input tensor the network is built on.
    depth (int): Network depth; must be of the form 6n+2 (e.g. 20, 32, 44).
    variational (bool): If True, the second convolution of each residual
      block is a TFP Flipout layer (without batch norm), and 'selu' replaces
      'relu' as the activation everywhere.
    std_prior_scale (float): Scale for the log-normal hyperprior of the
      variational layers.
    always_on_dropout_rate (float): Dropout rate applied before every
      convolution, active at both train and test time; None/0 disables it.
    examples_per_epoch (int): Examples per epoch, used to scale the
      variational KL term.
    eb_prior_fn (callable): Empirical Bayes kernel prior for the variational
      layers.
    no_first_layer_dropout (bool): If True, no dropout before the stem conv.
    num_filters (int): Filters in the first stack; doubled for each later
      stack.

  Returns:
    Flattened feature tensor.

  Raises:
    ValueError: If depth is not of the form 6n+2 with n >= 0.
  """
  # Negative values such as -4 satisfy the modulus test but are meaningless
  # (and would make the Fixup scale depth**(-1/4) complex), so reject them.
  if depth < 2 or (depth - 2) % 6 != 0:
    raise ValueError('depth should be 6n+2 (eg 20, 32, 44 in [a])')
  # Start model definition. Floor division avoids a float round-trip; int()
  # still tolerates a float-valued depth such as 20.0.
  num_res_blocks = int((depth - 2) // 6)

  activation = 'selu' if variational else 'relu'
  resnet_layer = functools.partial(
      _resnet_layer,
      activation=activation,
      depth=depth,
      std_prior_scale=std_prior_scale,
      always_on_dropout_rate=always_on_dropout_rate,
      examples_per_epoch=examples_per_epoch,
      eb_prior_fn=eb_prior_fn)

  logging.info('Starting ResNet build.')
  # Stem convolution.
  x = resnet_layer(
      inputs=input_layer,
      num_filters=num_filters,
      always_on_dropout_rate=(None if no_first_layer_dropout
                              else always_on_dropout_rate))
  # Instantiate the stack of residual units.
  for stack in range(3):
    for res_block in range(num_res_blocks):
      logging.info('Starting ResNet stack #%d block #%d.', stack, res_block)
      # First block of every stack except the first one downsamples.
      downsample = stack > 0 and res_block == 0
      strides = 2 if downsample else 1
      y = resnet_layer(inputs=x, num_filters=num_filters, strides=strides)
      # Only this second conv of the block is (optionally) variational; its
      # activation is deferred until after the residual addition below.
      y = resnet_layer(inputs=y, num_filters=num_filters, activation=None,
                       variational=variational)
      if downsample:
        # Linear 1x1 projection shortcut to match the changed dimensions.
        x = resnet_layer(inputs=x,
                         num_filters=num_filters,
                         kernel_size=1,
                         strides=strides,
                         activation=None,
                         batch_norm=False)
      x = keras.layers.add([x, y])
      x = keras.layers.Activation(activation)(x)
    num_filters *= 2

  # Pool and flatten only; the classifier head is not added here. v1 does not
  # use BN after the last shortcut connection + activation.
  # NOTE(review): pool_size=8 presumably matches an 8x8 final feature map
  # (32x32 inputs after two stride-2 stacks) -- confirm for other input sizes.
  x = keras.layers.AveragePooling2D(pool_size=8)(x)
  return keras.layers.Flatten()(x)
167

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.