# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions for MLP Networks.
"""
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
22
class MLP(object):
  """Definition of MLP Networks.

  A two-layer fully connected encoder, plus a parallel two-layer head
  (`confidence_model`) that maps an encoded representation to a
  per-dimension variance estimate.
  """

  def __init__(self, keep_prob, wd, feature_dim):
    """Creates a model for classifying using MLP encoding.

    Args:
      keep_prob: The rate of keeping one neuron in Dropout.
      wd: The co-efficient of weight decay.
      feature_dim: the dimension of the representation space.
    """
    super(MLP, self).__init__()

    # L2 weight decay applied to every dense-layer kernel built below.
    self.regularizer = contrib_layers.l2_regularizer(scale=wd)
    self.initializer = contrib_layers.xavier_initializer()
    # NOTE(review): not referenced by any layer in this class; presumably
    # kept for API parity with sibling network definitions — confirm.
    self.variance_initializer = contrib_layers.variance_scaling_initializer(
        factor=0.1,
        mode='FAN_IN',
        uniform=False,
        seed=None,
        dtype=tf.dtypes.float32)
    # tf.layers.dropout expects a *drop* rate, not a keep probability.
    self.drop_rate = 1 - keep_prob
    self.feature_dim = feature_dim

  def encoder(self, inputs, training):
    """Forwards a batch of inputs through the two-layer MLP encoder.

    Args:
      inputs: A Tensor representing a batch of inputs.
      training: A boolean. Unused by this encoder (no dropout/batch-norm
        here); kept so the signature matches encoders that do need it.

    Returns:
      A Tensor of shape [<batch_size>, self.feature_dim] holding the
      encoded representations (fc -> relu -> fc2, no final activation).
    """
    # pylint: disable=unused-argument
    out = tf.layers.dense(
        inputs,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc')
    out = tf.nn.relu(out)
    out = tf.layers.dense(
        out,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc2')
    return out

  def confidence_model(self, mu, training):
    """Given a batch of mu, output a batch of variance.

    Args:
      mu: A Tensor of representations, typically the output of `encoder`.
      training: A boolean. When True, dropout is applied before each
        dense layer; when False, dropout is a no-op.

    Returns:
      A Tensor of shape [<batch_size>, self.feature_dim] with the raw
      (unactivated) variance prediction for each dimension.
    """
    out = tf.layers.dropout(mu, rate=self.drop_rate, training=training)
    out = tf.layers.dense(
        out,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc_variance')
    out = tf.nn.relu(out)
    out = tf.layers.dropout(out, rate=self.drop_rate, training=training)
    out = tf.layers.dense(
        out,
        units=self.feature_dim,
        kernel_initializer=self.initializer,
        kernel_regularizer=self.regularizer,
        name='fc_variance2')
    return out
94
95
def mlp(keep_prob, wd, feature_dim):
  """Builds and returns an `MLP` network.

  Args:
    keep_prob: The rate of keeping one neuron in Dropout.
    wd: The co-efficient of weight decay.
    feature_dim: The dimension of the representation space.

  Returns:
    A newly constructed `MLP` instance.
  """
  return MLP(keep_prob=keep_prob, wd=wd, feature_dim=feature_dim)
99