google-research

Форк
0
98 строк · 3.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Contains definitions for MLP Networks.
17
"""
18
from __future__ import print_function
19
import tensorflow.compat.v1 as tf
20
from tensorflow.contrib import layers as contrib_layers
21

22

23
class MLP(object):
24
  """Definition of MLP Networks."""
25

26
  def __init__(self, keep_prob, wd, feature_dim):
27
    """Creates a model for classifying using MLP encoding.
28

29
    Args:
30
      keep_prob: The rate of keeping one neuron in Dropout.
31
      wd: The co-efficient of weight decay.
32
      feature_dim: the dimension of the representation space.
33
    """
34
    super(MLP, self).__init__()
35

36
    self.regularizer = contrib_layers.l2_regularizer(scale=wd)
37
    self.initializer = contrib_layers.xavier_initializer()
38
    self.variance_initializer = contrib_layers.variance_scaling_initializer(
39
        factor=0.1,
40
        mode='FAN_IN',
41
        uniform=False,
42
        seed=None,
43
        dtype=tf.dtypes.float32)
44
    self.drop_rate = 1 - keep_prob
45
    self.feature_dim = feature_dim
46

47
  def encoder(self, inputs, training):
48
    """Forwards a batch of inputs.
49

50
    Args:
51
      inputs: A Tensor representing a batch of inputs.
52
      training: A boolean. Set to True to add operations required only when
53
        training the classifier.
54

55
    Returns:
56
      A logits Tensor. If self.neck is true, the logits Tensor is with shape
57
      [<batch_size>, self.num_classes]. If self.neck is not true, the logits
58
      Tensor is with shape [<batch_size>, 256].
59
    """
60
    # pylint: disable=unused-argument
61
    out = tf.layers.dense(
62
        inputs,
63
        units=self.feature_dim,
64
        kernel_initializer=self.initializer,
65
        kernel_regularizer=self.regularizer,
66
        name='fc')
67
    out = tf.nn.relu(out)
68
    out = tf.layers.dense(
69
        out,
70
        units=self.feature_dim,
71
        kernel_initializer=self.initializer,
72
        kernel_regularizer=self.regularizer,
73
        name='fc2')
74
    return out
75

76
  def confidence_model(self, mu, training):
77
    """Given a batch of mu, output a batch of variance."""
78
    out = tf.layers.dropout(mu, rate=self.drop_rate, training=training)
79
    out = tf.layers.dense(
80
        out,
81
        units=self.feature_dim,
82
        kernel_initializer=self.initializer,
83
        kernel_regularizer=self.regularizer,
84
        name='fc_variance')
85
    out = tf.nn.relu(out)
86
    out = tf.layers.dropout(out, rate=self.drop_rate, training=training)
87
    out = tf.layers.dense(
88
        out,
89
        units=self.feature_dim,
90
        kernel_initializer=self.initializer,
91
        kernel_regularizer=self.regularizer,
92
        name='fc_variance2')
93
    return out
94

95

96
def mlp(keep_prob, wd, feature_dim):
97
  net = MLP(keep_prob=keep_prob, wd=wd, feature_dim=feature_dim)
98
  return net
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.