google-research

Форк
0
94 строки · 3.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Contains functions used to build cnc and siamese net models."""
17
from __future__ import division
18

19
from tensorflow.compat.v1.keras import layers
20
from tensorflow.compat.v1.keras.regularizers import l2
21

22

23
def stack_layers(inputs, net_layers, kernel_initializer='glorot_uniform'):
24
  """Builds the architecture of the network by applying each layer specified in net_layers to inputs.
25

26
  Args:
27
    inputs: a dict containing input_types and input_placeholders for each key
28
      and value pair, respecively.
29
    net_layers:  a list of dicts containing all layers to be used in the
30
      network, where each dict describes one such layer. each dict requires the
31
      key 'type'. all other keys are dependent on the layer type.
32
    kernel_initializer: initialization configuration passed to keras (see keras
33
      initializers).
34

35
  Returns:
36
    outputs: a dict formatted in much the same way as inputs. it
37
      contains input_types and output_tensors for each key and value pair,
38
      respectively, where output_tensors are the outputs of the
39
      input_placeholders in inputs after each layer in net_layers is applied.
40
  """
41
  outputs = dict()
42

43
  for key in inputs:
44
    outputs[key] = inputs[key]
45

46
  for layer in net_layers:
47
    # check for l2_reg argument
48
    l2_reg = layer.get('l2_reg')
49
    if l2_reg:
50
      l2_reg = l2(layer['l2_reg'])
51

52
    # create the layer
53
    if layer['type'] in [
54
        'softplus', 'softsign', 'softmax', 'tanh', 'sigmoid', 'relu', 'selu'
55
    ]:
56
      l = layers.Dense(
57
          layer['size'],
58
          activation=layer['type'],
59
          kernel_initializer=kernel_initializer,
60
          kernel_regularizer=l2_reg,
61
          name=layer.get('name'))
62
    elif layer['type'] == 'None':
63
      l = layers.Dense(
64
          layer['size'],
65
          kernel_initializer=kernel_initializer,
66
          kernel_regularizer=l2_reg,
67
          name=layer.get('name'))
68
    elif layer['type'] == 'Conv2D':
69
      l = layers.Conv2D(
70
          layer['channels'],
71
          kernel_size=layer['kernel'],
72
          activation='relu',
73
          data_format='channels_last',
74
          kernel_regularizer=l2_reg,
75
          name=layer.get('name'))
76
    elif layer['type'] == 'BatchNormalization':
77
      l = layers.BatchNormalization(name=layer.get('name'))
78
    elif layer['type'] == 'MaxPooling2D':
79
      l = layers.MaxPooling2D(
80
          pool_size=layer['pool_size'],
81
          data_format='channels_first',
82
          name=layer.get('name'))
83
    elif layer['type'] == 'Dropout':
84
      l = layers.Dropout(layer['rate'], name=layer.get('name'))
85
    elif layer['type'] == 'Flatten':
86
      l = layers.Flatten(name=layer.get('name'))
87
    else:
88
      raise ValueError("Invalid layer type '{}'".format(layer['type']))
89

90
    # apply the layer to each input in inputs
91
    for k in outputs:
92
      outputs[k] = l(outputs[k])
93

94
  return outputs
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.