# Source repository: google-research
# 151 lines · 4.9 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Contains models used in the experiments.
17"""
18
19from __future__ import absolute_import20from __future__ import division21from __future__ import print_function22
23from absl import logging24import tensorflow.compat.v2 as tf25
26from cold_posterior_bnn.core import frn27from cold_posterior_bnn.imdb import imdb_model28
29
def build_cnnlstm(num_words, sequence_length, pfac):
  """Builds the CNN-LSTM model by delegating to `imdb_model.cnn_lstm_nd`.

  This is a thin wrapper: the arguments are forwarded unchanged, but note that
  the callee takes the prior factory as its first positional argument.

  Args:
    num_words: forwarded to `imdb_model.cnn_lstm_nd` (presumably the
      vocabulary size -- confirm against `imdb_model`).
    sequence_length: forwarded to `imdb_model.cnn_lstm_nd` (presumably the
      length of each input sequence -- confirm against `imdb_model`).
    pfac: prior factory object, forwarded as the first argument.

  Returns:
    Whatever `imdb_model.cnn_lstm_nd` returns (presumably a tf.keras.Model).
  """
  return imdb_model.cnn_lstm_nd(pfac, num_words, sequence_length)
34
def build_resnet_v1(input_shape, depth, num_classes, pfac, use_frn=False,
                    use_internal_bias=True):
  """Builds ResNet v1.

  The network consists of one initial conv layer, three stacks of
  `(depth - 2) // 6` residual blocks (two 3x3 conv layers per block), average
  pooling and a final Dense layer. The number of filters starts at 16 and is
  doubled after every stack. The first block of the second and third stacks
  downsamples with stride 2 and uses a 1x1 conv projection on the shortcut.

  Args:
    input_shape: shape of the model input, passed as
      `tf.keras.layers.Input(shape=input_shape)`.
    depth: ResNet depth. Must be of the form 6n+2 with n >= 0
      (e.g. 20, 32, 44).
    num_classes: Number of output classes (units of the final Dense layer).
    pfac: priorfactory.PriorFactory object. It is called on every Conv2D,
      Dense, FRN and TLU layer, and the returned layer is applied to the
      tensor. BatchNormalization and Activation layers are not wrapped.
    use_frn: if True, then use Filter Response Normalization (FRN) instead of
      batchnorm, and TLU instead of ReLU after each residual addition.
    use_internal_bias: if True, use biases in all Conv layers.
      If False, only use a bias in the final Dense layer.

  Returns:
    tf.keras.Model.

  Raises:
    ValueError: if `depth` is smaller than 2 or is not of the form 6n+2.
  """
  # Validate before building anything. Negative depths such as -4 satisfy
  # (depth - 2) % 6 == 0 but would yield a negative block count and silently
  # build a network without residual blocks, so reject them explicitly.
  if depth < 2 or (depth - 2) % 6 != 0:
    raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')

  def resnet_layer(inputs,
                   filters,
                   kernel_size=3,
                   strides=1,
                   activation=None,
                   pfac=None,
                   use_frn=False,
                   use_bias=True):
    """2D Convolution-Normalization-Activation stack builder.

    Applies a Conv2D layer, then either FRN or BatchNormalization, then an
    optional activation.

    Args:
      inputs: tf.Tensor.
      filters: Number of filters for Conv2D.
      kernel_size: Kernel dimensions for Conv2D.
      strides: Stride dimensions for Conv2D.
      activation: activation passed to tf.keras.layers.Activation, or None to
        skip the activation.
      pfac: prior.PriorFactory object used to wrap the Conv2D and FRN layers.
      use_frn: if True, use Filter Response Normalization (FRN) layer
        instead of BatchNormalization.
      use_bias: if True, use biases in Conv layers.

    Returns:
      tf.Tensor.
    """
    x = inputs
    logging.info('Applying conv layer.')
    x = pfac(tf.keras.layers.Conv2D(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer='he_normal',
        use_bias=use_bias))(x)

    if use_frn:
      x = pfac(frn.FRN())(x)
    else:
      x = tf.keras.layers.BatchNormalization()(x)
    if activation is not None:
      x = tf.keras.layers.Activation(activation)(x)
    return x

  # Main network code
  num_res_blocks = (depth - 2) // 6
  filters = 16

  logging.info('Starting ResNet build.')
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = resnet_layer(inputs,
                   filters=filters,
                   activation='relu',
                   pfac=pfac,
                   use_frn=use_frn,
                   use_bias=use_internal_bias)
  for stack in range(3):
    for res_block in range(num_res_blocks):
      logging.info('Starting ResNet stack #%d block #%d.', stack, res_block)
      # First block of every stack except the first one halves the resolution.
      downsample = stack > 0 and res_block == 0
      strides = 2 if downsample else 1
      y = resnet_layer(x,
                       filters=filters,
                       strides=strides,
                       activation='relu',
                       pfac=pfac,
                       use_frn=use_frn,
                       use_bias=use_internal_bias)
      y = resnet_layer(y,
                       filters=filters,
                       activation=None,
                       pfac=pfac,
                       use_frn=use_frn,
                       use_bias=use_internal_bias)
      if downsample:
        # linear projection residual shortcut connection to match changed dims
        x = resnet_layer(x,
                         filters=filters,
                         kernel_size=1,
                         strides=strides,
                         activation=None,
                         pfac=pfac,
                         use_frn=use_frn,
                         use_bias=use_internal_bias)
      x = tf.keras.layers.add([x, y])
      if use_frn:
        x = pfac(frn.TLU())(x)
      else:
        x = tf.keras.layers.Activation('relu')(x)
    # Double the filter count once per stack; the projection shortcut above
    # only adapts the channel count at the start of a stack.
    filters *= 2

  # v1 does not use BN after last shortcut connection-ReLU
  # NOTE(review): pool_size=8 presumably assumes 32x32 inputs (8x8 feature
  # maps after the two stride-2 stages) -- confirm for other input sizes.
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = pfac(tf.keras.layers.Dense(
      num_classes,
      kernel_initializer='he_normal'))(x)

  logging.info('ResNet successfully built.')
  return tf.keras.models.Model(inputs=inputs, outputs=x)