# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=g-direct-tensorflow-import
# pylint: disable=protected-access

r"""Models."""


from absl import logging

import tensorflow.compat.v1 as tf
from differentiable_data_selection import modeling_utils as ops

class Wrn28k(object):
  """WideResNet-28 with width multiplier `k` (WRN-28-k).

  A 16-filter stem convolution followed by three residual stages of four
  `ops.wrn_block`s each; the head is BN -> ReLU -> global average pool ->
  dropout -> dense.  With `return_scores=True` the head instead emits one
  tanh-bounded scalar per example (a data-selection score) rather than
  class logits.
  """

  def __init__(self, params, k=2):
    """Stores the configuration; no variables are created until `__call__`.

    Args:
      params: hyperparameter object.  This class reads `use_bfloat16`,
        `dense_dropout_rate`, `num_classes` and `scorer_clip` from it.
      k: width multiplier for the three residual stages.
    """
    self.params = params
    self.name = f'wrn-28-{k}'
    self.k = k
    logging.info(f'Build `wrn-28-{k}` under scope `{self.name}`')

  def __call__(self, x, training, return_scores=False):
    """Builds the network graph and returns float32 logits (or scores).

    Args:
      x: input image batch, NHWC.  (Presumably CIFAR-sized 32x32 inputs —
        TODO confirm; nothing here enforces a spatial size.)
      training: Python bool; selects train-mode batch norm and dropout.
      return_scores: if True, output `params.scorer_clip * tanh(dense_1(x))`
        per example instead of `num_classes` logits.

    Returns:
      A `[batch, num_classes]` float32 tensor of logits, or a `[batch, 1]`
      float32 tensor of bounded scores when `return_scores` is True.
    """
    if training:
      logging.info(f'Call {self.name} for `training`')
    else:
      logging.info(f'Call {self.name} for `eval`')

    params = self.params
    k = self.k
    if params.use_bfloat16:
      ops.use_bfloat16()

    # Per-stage widths: s[0] is the stem, s[1..3] the residual stages.
    # Standard WRN-28 keeps the stem at 16 filters independent of `k` and
    # widens only the stages to 16k/32k/64k; `k == 135` is a special large
    # configuration with hand-picked widths 135/270/540.
    # Bug fix: the general-case stem width was `16 * k`; it must be 16,
    # matching the `k == 135` branch and the WRN paper.
    s = [16, 135, 135 * 2, 135 * 4] if k == 135 else [16, 16 * k, 32 * k, 64 * k]

    with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
      with tf.variable_scope('stem'):
        x = ops.conv2d(x, 3, s[0], 1)
        ops.log_tensor(x, True)

      # Stage 1: width s[1], all strides 1 (no downsampling).
      x = ops.wrn_block(x, params, s[1], 1, training, 'block_1')
      x = ops.wrn_block(x, params, s[1], 1, training, 'block_2')
      x = ops.wrn_block(x, params, s[1], 1, training, 'block_3')
      x = ops.wrn_block(x, params, s[1], 1, training, 'block_4')

      # Stage 2: width s[2], first block downsamples by 2.
      x = ops.wrn_block(x, params, s[2], 2, training, 'block_5')
      x = ops.wrn_block(x, params, s[2], 1, training, 'block_6')
      x = ops.wrn_block(x, params, s[2], 1, training, 'block_7')
      x = ops.wrn_block(x, params, s[2], 1, training, 'block_8')

      # Stage 3: width s[3], first block downsamples by 2.
      x = ops.wrn_block(x, params, s[3], 2, training, 'block_9')
      x = ops.wrn_block(x, params, s[3], 1, training, 'block_10')
      x = ops.wrn_block(x, params, s[3], 1, training, 'block_11')
      x = ops.wrn_block(x, params, s[3], 1, training, 'block_12')

      with tf.variable_scope('head'):
        x = ops.batch_norm(x, params, training)
        x = ops.relu(x)
        x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
        ops.log_tensor(x, True)

        x = ops.dropout(x, params.dense_dropout_rate, training)
        if return_scores:
          # Single unbiased unit squashed into (-scorer_clip, scorer_clip).
          x = ops.dense(x, 1, use_bias=False)
          x = params.scorer_clip * tf.tanh(x, name='scores')
        else:
          x = ops.dense(x, params.num_classes)
        # Cast back to float32 in case bfloat16 was enabled above.
        x = tf.cast(x, dtype=tf.float32, name='logits')
        ops.log_tensor(x, True)

    return x


class ResNet50(object):
  """Bottleneck ResNet-50 graph builder (TF1 variable-scope style)."""

  def __init__(self, params):
    """Keeps the config; variables are created lazily in `__call__`.

    Args:
      params: hyperparameter object; `use_bfloat16`, `dense_dropout_rate`
        and `num_classes` are read from it.
    """
    self.params = params
    self.name = 'resnet-50'
    logging.info(f'Build `resnet-50` under scope `{self.name}`')

  def __call__(self, x, training):
    """Builds the network and returns `[batch, num_classes]` float32 logits.

    Args:
      x: input image batch, NHWC.
      training: Python bool; selects train-mode batch norm and dropout.

    Returns:
      Float32 logits tensor named `logits`.
    """
    mode = 'training' if training else 'eval'
    logging.info(f'Call {self.name} for `{mode}`')

    params = self.params
    if params.use_bfloat16:
      ops.use_bfloat16()

    # (num_out_filters, stride) for blocks 1..16; a stride of 2 opens each
    # new stage.  Stage depths 3/4/6/3 are the standard ResNet-50 layout.
    plan = (
        [(256, 1)] * 3
        + [(512, 2)] + [(512, 1)] * 3
        + [(1024, 2)] + [(1024, 1)] * 5
        + [(2048, 2)] + [(2048, 1)] * 2
    )

    with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
      with tf.variable_scope('stem'):
        x = ops.conv2d(x, 7, 64, 2)
        x = ops.batch_norm(x, params, training)
        x = ops.relu(x, leaky=0.)
        ops.log_tensor(x, True)

        x = ops.max_pool(x, 3, 2)
        ops.log_tensor(x, True)

      for block_id, (filters, stride) in enumerate(plan, start=1):
        x = ops.resnet_block(x,
                             params=params,
                             num_out_filters=filters,
                             stride=stride,
                             training=training,
                             name=f'block_{block_id}')

      with tf.variable_scope('head'):
        x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
        ops.log_tensor(x, True)

        x = ops.dropout(x, params.dense_dropout_rate, training)
        x = ops.dense(x, params.num_classes)
        x = tf.cast(x, dtype=tf.float32, name='logits')
        ops.log_tensor(x, True)

    return x