# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains models used in the experiments."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
import tensorflow.compat.v2 as tf

from cold_posterior_bnn.core import frn
from cold_posterior_bnn.imdb import imdb_model


def build_cnnlstm(num_words, sequence_length, pfac):
  """Builds the CNN-LSTM model for the IMDB experiments."""
  model = imdb_model.cnn_lstm_nd(pfac, num_words, sequence_length)
  return model

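# Example usage (an illustrative sketch, not part of the experiment pipeline):
# the literal values below are assumed placeholders for the IMDB setup, and
# the pass-through pfac is a hypothetical stand-in for the
# priorfactory.PriorFactory object that the experiments construct elsewhere.
#
#   model = build_cnnlstm(
#       num_words=20000, sequence_length=100, pfac=lambda layer: layer)
#   model.summary()

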
def build_resnet_v1(input_shape, depth, num_classes, pfac, use_frn=False,
                    use_internal_bias=True):
  """Builds ResNet v1.

  Args:
    input_shape: Shape of the model inputs (excluding the batch dimension),
      passed to tf.keras.layers.Input.
    depth: ResNet depth.
    num_classes: Number of output classes.
    pfac: priorfactory.PriorFactory class.
    use_frn: if True, then use Filter Response Normalization (FRN) instead of
      batchnorm.
    use_internal_bias: if True, use biases in all Conv layers.
      If False, only use a bias in the final Dense layer.

  Returns:
    tf.keras.Model.
  """
  def resnet_layer(inputs,
                   filters,
                   kernel_size=3,
                   strides=1,
                   activation=None,
                   pfac=None,
                   use_frn=False,
                   use_bias=True):
    """2D Convolution-Batch Normalization-Activation stack builder.

    Args:
      inputs: tf.Tensor.
      filters: Number of filters for Conv2D.
      kernel_size: Kernel dimensions for Conv2D.
      strides: Stride dimensions for Conv2D.
      activation: tf.keras.activations.Activation.
      pfac: priorfactory.PriorFactory object.
      use_frn: if True, use Filter Response Normalization (FRN) layer.
      use_bias: if True, use biases in Conv layers.

    Returns:
      tf.Tensor.
    """
    x = inputs
    logging.info('Applying conv layer.')
    x = pfac(tf.keras.layers.Conv2D(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer='he_normal',
        use_bias=use_bias))(x)

    if use_frn:
      x = pfac(frn.FRN())(x)
    else:
      x = tf.keras.layers.BatchNormalization()(x)
    if activation is not None:
      x = tf.keras.layers.Activation(activation)(x)
    return x

  # Main network code
  num_res_blocks = (depth - 2) // 6
  filters = 16
  if (depth - 2) % 6 != 0:
    raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')

  logging.info('Starting ResNet build.')
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = resnet_layer(inputs,
                   filters=filters,
                   activation='relu',
                   pfac=pfac,
                   use_frn=use_frn,
                   use_bias=use_internal_bias)
  for stack in range(3):
    for res_block in range(num_res_blocks):
      logging.info('Starting ResNet stack #%d block #%d.', stack, res_block)
      strides = 1
      if stack > 0 and res_block == 0:  # first layer but not first stack
        strides = 2  # downsample
      y = resnet_layer(x,
                       filters=filters,
                       strides=strides,
                       activation='relu',
                       pfac=pfac,
                       use_frn=use_frn,
                       use_bias=use_internal_bias)
      y = resnet_layer(y,
                       filters=filters,
                       activation=None,
                       pfac=pfac,
                       use_frn=use_frn,
                       use_bias=use_internal_bias)
      if stack > 0 and res_block == 0:  # first layer but not first stack
        # linear projection residual shortcut connection to match changed dims
        x = resnet_layer(x,
                         filters=filters,
                         kernel_size=1,
                         strides=strides,
                         activation=None,
                         pfac=pfac,
                         use_frn=use_frn,
                         use_bias=use_internal_bias)
      x = tf.keras.layers.add([x, y])
      if use_frn:
        x = pfac(frn.TLU())(x)
      else:
        x = tf.keras.layers.Activation('relu')(x)
    filters *= 2

  # v1 does not use BN after last shortcut connection-ReLU
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = pfac(tf.keras.layers.Dense(
      num_classes,
      kernel_initializer='he_normal'))(x)

  logging.info('ResNet successfully built.')
  return tf.keras.models.Model(inputs=inputs, outputs=x)
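

# A minimal end-to-end sketch (not part of the original module): the identity
# prior factory below is a hypothetical stand-in for the
# priorfactory.PriorFactory objects used in the experiments, and the
# CIFAR-10-style shape/depth values are assumptions chosen only to exercise
# the builder.
if __name__ == '__main__':

  def identity_pfac(layer):
    # Hypothetical pass-through prior factory: returns each layer unchanged.
    return layer

  resnet = build_resnet_v1(
      input_shape=(32, 32, 3),  # Assumed CIFAR-10 image shape.
      depth=20,  # Must satisfy depth = 6n + 2.
      num_classes=10,
      pfac=identity_pfac,
      use_frn=False,  # Set True to use FRN + TLU instead of BatchNorm + ReLU.
      use_internal_bias=True)
  resnet.summary()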