google-research
94 строки · 3.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Contains functions used to build cnc and siamese net models."""
17from __future__ import division18
19from tensorflow.compat.v1.keras import layers20from tensorflow.compat.v1.keras.regularizers import l221
22
def stack_layers(inputs, net_layers, kernel_initializer='glorot_uniform'):
  """Builds the network by applying each layer in net_layers to every input.

  Each layer object created here is applied to every entry of `inputs` in
  turn, so all inputs flow through the *same* layer instances (their weights
  are shared across inputs).

  Args:
    inputs: a dict mapping input types (keys) to input placeholders/tensors
      (values).
    net_layers: a list of dicts, each describing one layer of the network.
      Every dict requires the key 'type'; the other keys it needs depend on
      that type:
        - 'softplus', 'softsign', 'softmax', 'tanh', 'sigmoid', 'relu',
          'selu': Dense layer with that activation; requires 'size'.
        - 'None': Dense layer with no activation; requires 'size'.
        - 'Conv2D': convolution with ReLU activation; requires 'channels'
          and 'kernel'.
        - 'MaxPooling2D': requires 'pool_size'.
        - 'Dropout': requires 'rate'.
        - 'BatchNormalization', 'Flatten': no extra keys.
      Optional keys: 'name' (layer name); 'l2_reg' (L2 kernel-regularization
      factor, used by Dense and Conv2D layers; a falsy value disables it);
      'data_format' (Conv2D / MaxPooling2D only, defaults to
      'channels_last').
    kernel_initializer: initialization configuration passed to keras for the
      kernels of Dense and Conv2D layers (see keras initializers).

  Returns:
    A dict with the same keys as `inputs`, where each value is the result of
    applying every layer in net_layers, in order, to the corresponding input.

  Raises:
    ValueError: if a layer dict has an unrecognized 'type'.
  """
  dense_activations = ('softplus', 'softsign', 'softmax', 'tanh', 'sigmoid',
                       'relu', 'selu')

  # Shallow copy so the caller's dict is never mutated.
  outputs = {key: inputs[key] for key in inputs}

  for layer in net_layers:
    layer_type = layer['type']
    name = layer.get('name')

    # Only build a regularizer for a truthy factor. Falsy values (e.g. 0) must
    # become None, since keras does not accept a bare number as a regularizer.
    l2_reg = l2(layer['l2_reg']) if layer.get('l2_reg') else None

    if layer_type in dense_activations:
      net_layer = layers.Dense(
          layer['size'],
          activation=layer_type,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=l2_reg,
          name=name)
    elif layer_type == 'None':
      # Dense layer without an activation (linear output).
      net_layer = layers.Dense(
          layer['size'],
          kernel_initializer=kernel_initializer,
          kernel_regularizer=l2_reg,
          name=name)
    elif layer_type == 'Conv2D':
      net_layer = layers.Conv2D(
          layer['channels'],
          kernel_size=layer['kernel'],
          activation='relu',
          data_format=layer.get('data_format', 'channels_last'),
          kernel_initializer=kernel_initializer,
          kernel_regularizer=l2_reg,
          name=name)
    elif layer_type == 'BatchNormalization':
      net_layer = layers.BatchNormalization(name=name)
    elif layer_type == 'MaxPooling2D':
      # Default to the same layout as Conv2D. Pooling in 'channels_first'
      # after a 'channels_last' convolution would pool over the width and
      # channel axes instead of height and width.
      net_layer = layers.MaxPooling2D(
          pool_size=layer['pool_size'],
          data_format=layer.get('data_format', 'channels_last'),
          name=name)
    elif layer_type == 'Dropout':
      net_layer = layers.Dropout(layer['rate'], name=name)
    elif layer_type == 'Flatten':
      net_layer = layers.Flatten(name=name)
    else:
      raise ValueError("Invalid layer type '{}'".format(layer_type))

    # Apply the same layer object to every input so weights are shared.
    for key in outputs:
      outputs[key] = net_layer(outputs[key])

  return outputs