google-research
154 строки · 5.1 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# pylint: disable=logging-format-interpolation
17# pylint: disable=g-direct-tensorflow-import
18# pylint: disable=protected-access
19
20r"""Models."""
21
22
23from absl import logging24
25import tensorflow.compat.v1 as tf26from differentiable_data_selection import modeling_utils as ops27
28
class Wrn28k(object):
  """WideResNet assembled from a conv stem, twelve `ops.wrn_block`s and a head.

  Instances are callable: `model(x, training)` builds the network under the
  variable scope `wrn-28-{k}` (with `reuse=tf.AUTO_REUSE`) and returns the
  output tensor.
  """

  def __init__(self, params, k=2):
    """Stores the configuration and derives the variable-scope name.

    Args:
      params: hyper-parameter object. `__call__` reads `use_bfloat16`,
        `dense_dropout_rate`, `num_classes` and (when scoring) `scorer_clip`
        from it, and forwards it to `ops.wrn_block` / `ops.batch_norm`.
      k: integer width multiplier used to size the stem and block groups.
    """
    self.params = params
    self.k = k
    self.name = f'wrn-28-{k}'
    logging.info(f'Build `wrn-28-{k}` under scope `{self.name}`')

  def __call__(self, x, training, return_scores=False):
    """Builds the network on `x` and returns its output.

    Args:
      x: input tensor. The head averages over axes 1 and 2, so this is
        presumably an NHWC image batch -- TODO confirm against callers.
      training: truthy when building the training graph; forwarded to the
        block, batch-norm and dropout helpers.
      return_scores: if True, the final dense layer has a single output
        without bias, squashed by `tanh` and scaled by `params.scorer_clip`.
        Otherwise the final dense layer has `params.num_classes` outputs.

    Returns:
      The output tensor, cast to `tf.float32`.
    """
    if training:
      message = f'Call {self.name} for `training`'
    else:
      message = f'Call {self.name} for `eval`'
    logging.info(message)

    params, k = self.params, self.k
    if params.use_bfloat16:
      ops.use_bfloat16()

    # widths[0] sizes the stem; widths[1:] size the three block groups.
    # k == 135 is special-cased: the stem stays at 16 instead of 16 * k.
    if k == 135:
      widths = [16, 135, 135 * 2, 135 * 4]
    else:
      widths = [16 * k, 16 * k, 32 * k, 64 * k]

    # (size, fourth positional argument -- presumably the stride, scope name)
    # for each block, listed in the order the blocks are applied.
    block_specs = (
        (widths[1], 1, 'block_1'),
        (widths[1], 1, 'block_2'),
        (widths[1], 1, 'block_3'),
        (widths[1], 1, 'block_4'),
        (widths[2], 2, 'block_5'),
        (widths[2], 1, 'block_6'),
        (widths[2], 1, 'block_7'),
        (widths[2], 1, 'block_8'),
        (widths[3], 2, 'block_9'),
        (widths[3], 1, 'block_10'),
        (widths[3], 1, 'block_11'),
        (widths[3], 1, 'block_12'),
    )

    with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
      with tf.variable_scope('stem'):
        x = ops.conv2d(x, 3, widths[0], 1)
        ops.log_tensor(x, True)

      for width, stride, scope in block_specs:
        x = ops.wrn_block(x, params, width, stride, training, scope)

      with tf.variable_scope('head'):
        x = ops.batch_norm(x, params, training)
        x = ops.relu(x)
        x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
        ops.log_tensor(x, True)

        x = ops.dropout(x, params.dense_dropout_rate, training)
        if return_scores:
          x = ops.dense(x, 1, use_bias=False)
          x = params.scorer_clip * tf.tanh(x, name='scores')
        else:
          x = ops.dense(x, params.num_classes)
        # NOTE(review): the nesting of the cast/log below was reconstructed
        # from a whitespace-stripped copy; it is assumed to apply to both
        # branches -- confirm against the upstream file.
        x = tf.cast(x, dtype=tf.float32, name='logits')
        ops.log_tensor(x, True)

    return x
88
class ResNet50(object):
  """Bottleneck ResNet: a conv stem, sixteen `ops.resnet_block`s and a head.

  Instances are callable: `model(x, training)` builds the network under the
  variable scope `resnet-50` (with `reuse=tf.AUTO_REUSE`) and returns logits.
  """

  def __init__(self, params):
    """Stores the configuration and fixes the variable-scope name.

    Args:
      params: hyper-parameter object. `__call__` reads `use_bfloat16`,
        `dense_dropout_rate` and `num_classes` from it, and forwards it to
        `ops.batch_norm` / `ops.resnet_block`.
    """
    self.params = params
    self.name = 'resnet-50'
    logging.info(f'Build `resnet-50` under scope `{self.name}`')

  def __call__(self, x, training):
    """Builds the network on `x` and returns float32 logits.

    Args:
      x: input tensor. The head averages over axes 1 and 2, so this is
        presumably an NHWC image batch -- TODO confirm against callers.
      training: truthy when building the training graph; forwarded to the
        batch-norm, block and dropout helpers.

    Returns:
      Output of a dense layer with `params.num_classes` units, cast to
      `tf.float32`.
    """
    if training:
      message = f'Call {self.name} for `training`'
    else:
      message = f'Call {self.name} for `eval`'
    logging.info(message)

    params = self.params
    if params.use_bfloat16:
      ops.use_bfloat16()

    # (num_out_filters, stride, scope name) for each block, listed in the
    # order the blocks are applied: groups of 3, 4, 6 and 3 blocks.
    block_specs = (
        (256, 1, 'block_1'),
        (256, 1, 'block_2'),
        (256, 1, 'block_3'),
        (512, 2, 'block_4'),
        (512, 1, 'block_5'),
        (512, 1, 'block_6'),
        (512, 1, 'block_7'),
        (1024, 2, 'block_8'),
        (1024, 1, 'block_9'),
        (1024, 1, 'block_10'),
        (1024, 1, 'block_11'),
        (1024, 1, 'block_12'),
        (1024, 1, 'block_13'),
        (2048, 2, 'block_14'),
        (2048, 1, 'block_15'),
        (2048, 1, 'block_16'),
    )

    with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
      with tf.variable_scope('stem'):
        x = ops.conv2d(x, 7, 64, 2)
        x = ops.batch_norm(x, params, training)
        x = ops.relu(x, leaky=0.)
        ops.log_tensor(x, True)

        # NOTE(review): nesting reconstructed from a whitespace-stripped copy;
        # the pooling is assumed to sit inside the `stem` scope -- confirm.
        x = ops.max_pool(x, 3, 2)
        ops.log_tensor(x, True)

      for num_filters, block_stride, scope in block_specs:
        x = ops.resnet_block(x,
                             params=params,
                             num_out_filters=num_filters,
                             stride=block_stride,
                             training=training,
                             name=scope)

      with tf.variable_scope('head'):
        x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
        ops.log_tensor(x, True)

        x = ops.dropout(x, params.dense_dropout_rate, training)
        x = ops.dense(x, params.num_classes)
        x = tf.cast(x, dtype=tf.float32, name='logits')
        ops.log_tensor(x, True)

    return x