# google-research / uq_benchmark_2019 (source listing header; original noted "166 lines, 6.2 KB")
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""ResNet V1 implementation for UQ experiments.
17
18Mostly derived from Keras documentation:
19https://keras.io/examples/cifar10_resnet/
20"""
21
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from absl import logging
from six.moves import range
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from uq_benchmark_2019 import uq_utils

# Short aliases used throughout this module.
keras = tf.keras
tfd = tfp.distributions
36
def _resnet_layer(inputs,
                  num_filters=16,
                  kernel_size=3,
                  strides=1,
                  activation='relu',
                  depth=20,
                  batch_norm=True,
                  conv_first=True,
                  variational=False,
                  std_prior_scale=1.5,
                  eb_prior_fn=None,
                  always_on_dropout_rate=None,
                  examples_per_epoch=None):
  """2D Convolution-Batch Normalization-Activation stack builder.

  Args:
    inputs (tensor): input tensor from input image or previous layer
    num_filters (int): Conv2D number of filters
    kernel_size (int): Conv2D square kernel dimensions
    strides (int): Conv2D square stride dimensions
    activation (string): Activation function string.
    depth (int): ResNet depth; used for initialization scale.
    batch_norm (bool): whether to include batch normalization
    conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)
    variational (bool): Whether to use a variational convolutional layer.
    std_prior_scale (float): Scale for log-normal hyperprior.
    eb_prior_fn (callable): Empirical Bayes prior for use with TFP layers.
    always_on_dropout_rate (float): Dropout rate (active in train and test).
    examples_per_epoch (int): Number of examples per epoch for variational KL.

  Returns:
    x (tensor): tensor as input to the next layer
  """
  if variational:
    # KL divergence is scaled by the number of examples per epoch so the
    # per-batch loss matches the empirical-Bayes objective.
    divergence_fn = uq_utils.make_divergence_fn_for_empirical_bayes(
        std_prior_scale, examples_per_epoch)

    def fixup_init(shape, dtype=None):
      """Fixup initialization; see https://arxiv.org/abs/1901.09321."""
      return keras.initializers.he_normal()(shape, dtype=dtype) * depth**(-1/4)

    conv = tfp.layers.Convolution2DFlipout(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_prior_fn=eb_prior_fn,
        kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(
            loc_initializer=fixup_init),
        kernel_divergence_fn=divergence_fn)
  else:
    conv = keras.layers.Conv2D(num_filters,
                               kernel_size=kernel_size,
                               strides=strides,
                               padding='same',
                               kernel_initializer='he_normal',
                               kernel_regularizer=keras.regularizers.l2(1e-4))

  def apply_conv(net):
    """Apply (optional always-on dropout then) the conv layer to `net`."""
    logging.info('Applying conv layer; always_on_dropout=%s.',
                 always_on_dropout_rate)
    if always_on_dropout_rate:
      # training=True keeps dropout active at test time (MC dropout).
      net = keras.layers.Dropout(always_on_dropout_rate)(net, training=True)
    return conv(net)

  x = inputs
  if conv_first:
    x = apply_conv(x)
  # Batch norm is skipped for variational layers.
  if batch_norm and not variational:
    x = keras.layers.BatchNormalization()(x)
  if activation is not None:
    x = keras.layers.Activation(activation)(x)
  if not conv_first:
    x = apply_conv(x)
  return x
110
def build_resnet_v1(input_layer, depth,
                    variational,
                    std_prior_scale,
                    always_on_dropout_rate,
                    examples_per_epoch,
                    eb_prior_fn=None,
                    no_first_layer_dropout=False,
                    num_filters=16):
  """ResNet Version 1 Model builder [a]."""
  if (depth - 2) % 6 != 0:
    raise ValueError('depth should be 6n+2 (eg 20, 32, 44 in [a])')
  # Start model definition.
  num_res_blocks = int((depth - 2) / 6)

  # SELU pairs better with the variational layers than ReLU does here.
  activation = 'selu' if variational else 'relu'
  resnet_layer = functools.partial(
      _resnet_layer,
      activation=activation,
      depth=depth,
      std_prior_scale=std_prior_scale,
      always_on_dropout_rate=always_on_dropout_rate,
      examples_per_epoch=examples_per_epoch,
      eb_prior_fn=eb_prior_fn)

  logging.info('Starting ResNet build.')
  x = resnet_layer(
      inputs=input_layer,
      num_filters=num_filters,
      always_on_dropout_rate=(None if no_first_layer_dropout
                              else always_on_dropout_rate))
  # Instantiate the stack of residual units.
  for stack in range(3):
    for res_block in range(num_res_blocks):
      logging.info('Starting ResNet stack #%d block #%d.', stack, res_block)
      # First block of every stack except the first downsamples spatially.
      downsample = stack > 0 and res_block == 0
      strides = 2 if downsample else 1
      branch = resnet_layer(inputs=x, num_filters=num_filters,
                            strides=strides)
      branch = resnet_layer(inputs=branch, num_filters=num_filters,
                            activation=None, variational=variational)
      if downsample:
        # Linear projection residual shortcut connection to match changed
        # dimensions.
        x = resnet_layer(inputs=x,
                         num_filters=num_filters,
                         kernel_size=1,
                         strides=strides,
                         activation=None,
                         batch_norm=False)
      x = keras.layers.add([x, branch])
      x = keras.layers.Activation(activation)(x)
    num_filters *= 2

  # Add classifier on top.
  # v1 does not use BN after last shortcut connection-ReLU.
  x = keras.layers.AveragePooling2D(pool_size=8)(x)
  return keras.layers.Flatten()(x)