1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Provides a simple feed forward neural net class for regression tasks.
17
18The FeedForward class constructs a tensorflow model. The constructor takes a
19config object and expects metaparameter settings to be stored as config
20attributes. Any object that has the appropriate attributes can be passed in to
21configure the model since FeedForward does not assume the config object is
22anything more than a dumb struct and does not attempt to serialize it.
23Nonetheless, it will usually be an instance of tf.HParams. The model
24constructor sets up TF variables to hold the weights, but the fprop method
25builds the fprop graph for the model using the weights.
26
27A few notes on what metaparameters will be needed in general:
28
29For each fully connected (FC) layer, we need to select a size, an activation
function, a dropout rate, and an initialization scheme. At construction time, we
31only need the sizes and initialization.
32
33Right now there is no support for convolutional layers.
34
35Eventually, for each convolutional layer we need the activation function, a
36dropout rate, a filter size, a number of filters, an initialization scheme, and
37in principle padding and strides, but we will fix those. At construction time,
38we only need the filter size, number of filters, and initialization.
39
40Although in principle we can interleave FC and conv layers, life is complicated
41enough as it is. Let's do zero or more conv layers followed by zero or more FC
layers. During metaparameter search, based on limitations of metaparameter tuning
43policies, we will need to fix the number of layers of each type in a given
44study. We also might need to use introspection to add attributes to the hpconfig
45object the tuner gives us, since the tuner interface has limited flexibility
46for multi-dimensional metaparameters.
47
48The model class doesn't know anything about training, so training
49metaparameters in the config object will be ignored.
50"""
51
52import collections
53
54from six.moves import map
55
56import tensorflow.compat.v1 as tf
57from tensorflow.contrib import labeled_tensor as lt
58from xxx import layers as contrib_layers
59from xxx import framework as contrib_framework
60
61
# Lookup table from the config's nonlinearity name to the TF activation op.
nonlinearities = dict(
    relu=tf.nn.relu,
    elu=tf.nn.elu,
    tanh=tf.tanh,
    sigmoid=tf.sigmoid,
)
68
69
def _stack_inputs_by_rank(inputs):
  """Create 2D and 3D input tensors from a dictionary of inputs.

  3D inputs are stacked together for use in (optional) convolutional layers.
  2D inputs are only used in fully-connected layers.

  Args:
    inputs: Dict[str, lt.LabeledTensor] providing input features. All features
      must be 2D or 3D labeled tensors with a 'batch' axis as their first
      dimension. 3D tensors must have 'position' as their second axis. The
      last axis of all tensors is allowed to vary, because raw input features
      may have different names for labels that are more meaningful than
      generic "features" or "channels".

  Returns:
    Tuple[Optional[lt.LabeledTensor], Optional[lt.LabeledTensor]], where the
    first labeled tensor, if present, has axes ['batch', 'feature'] and the
    second labeled tensor, if present, has axes ['batch', 'position',
    'channel'].

  Raises:
    ValueError: if the result tensors do not have the same batch axis.
  """
  inputs_2d = []
  inputs_3d = []
  # Iterate keys in sorted order so outputs are fixed across randomized dict
  # iteration order.
  for key in sorted(inputs):
    tensor = inputs[key]
    if len(tensor.axes) == 2:
      # Normalize the trailing axis name so heterogeneous features concat.
      tensor = lt.rename_axis(tensor, list(tensor.axes.keys())[-1], 'feature')
      inputs_2d.append(tensor)
    elif len(tensor.axes) == 3:
      assert list(tensor.axes.values())[1].name == 'position'
      tensor = lt.rename_axis(tensor, list(tensor.axes.keys())[-1], 'channel')
      inputs_3d.append(tensor)
    else:
      raise AssertionError('unexpected rank')

  combined_2d = lt.concat(inputs_2d, 'feature') if inputs_2d else None
  combined_3d = lt.concat(inputs_3d, 'channel') if inputs_3d else None
  if combined_2d is not None and combined_3d is not None:
    # Bug fix: the original compared combined_2d's batch axis against itself,
    # so a batch-axis mismatch between 2D and 3D inputs was never detected.
    if list(combined_2d.axes.values())[0] != list(combined_3d.axes.values())[0]:
      raise ValueError('mismatched batch axis')
  return combined_2d, combined_3d
114
115
class FeedForward:
  """Class implementing a simple feedforward neural net in tensorflow.

  Attributes:
    batch_axis: lt.Axis for batches of examples.
    input_position_axis: lt.Axis for input positions.
    input_channel_axis: lt.Axis for input channels.
    logit_axis: lt.Axis for logit channels, output from the `fprop` method.
    config: a reference to the config object we used to specify the model. In
      general we expect it to be an instance of tf.HParams, but it could be
      anything with the right attributes.
    params: list of weights and biases
  """

  def __init__(self, dummy_inputs, logit_axis, config):
    # dummy_inputs is used only to trace the graph once so all variables get
    # created with the right shapes; its values are never trained on.
    self.logit_axis = logit_axis
    self.config = config

    # The final FC layer maps onto the logits, so its size and init factor are
    # fixed by logit_axis / output_init_factor rather than taken from the
    # per-hidden-layer config lists.
    self.fc_sizes = getattr(config, 'fc_hid_sizes', []) + [len(logit_axis)]
    self.fc_init_factors = (
        getattr(config, 'fc_init_factors', []) + [config.output_init_factor])

    if not dummy_inputs:
      raise ValueError('network has size 0 input')
    if logit_axis.size == 0:
      raise ValueError('network has size 0 output')

    # All per-layer FC metaparameter lists must have the same length.
    if len({
        len(self.fc_sizes), len(self.fc_init_factors), len(config.dropouts)
    }) != 1:
      raise ValueError('invalid hyperparameter config for fc layers')
    self.num_fc_layers = len(self.fc_sizes)

    # Gather per-layer conv settings (config.conv_depths, config.conv_widths,
    # ...) into a single namedtuple of parallel lists; missing attributes
    # default to no conv layers.
    self._conv_config = _ConvConfig(*[
        getattr(config, 'conv_' + field, []) for field in _ConvConfig._fields
    ])
    if len(set(map(len, self._conv_config))) != 1:
      raise ValueError('invalid hyperparameter config for conv layers')
    self.num_conv_layers = len(self._conv_config.depths)

    # make_template ensures the variables created on the first fprop call are
    # reused by every subsequent call.
    self.fprop = tf.make_template('feedforward', self._fprop)
    # create variables
    self.fprop(dummy_inputs, mode='test')
    self.params = contrib_framework.get_variables(
        scope=self.fprop.variable_scope.name)

  def _fprop(self, inputs, mode):
    """Builds the fprop graph from inputs up to logits.

    Args:
      inputs: input LabeledTensor with axes [batch_axis, input_position_axis,
        input_channel_axis].
      mode: either 'test' or 'train', determines whether we add dropout nodes

    Returns:
      Logits tensor with axes [batch_axis, logit_axis].

    Raises:
      ValueError: mode must be 'train' or 'test'
    """
    if mode not in ['test', 'train']:
      raise ValueError('mode must be one of "train" or "test"')
    is_training = mode == 'train'

    inputs_2d, inputs_3d = _stack_inputs_by_rank(inputs)

    if inputs_2d is None and inputs_3d is None:
      raise ValueError('feedforward model has no inputs')

    # Get the batch axis from the actual inputs, because we set up the graph
    # with unknown batch size.
    example_inputs = inputs_3d if inputs_2d is None else inputs_2d
    batch_axis = example_inputs.axes['batch']

    w_initializer = tf.uniform_unit_scaling_initializer
    nonlinearity = nonlinearities[self.config.nonlinearity]

    if inputs_3d is not None:
      # Zip the parallel per-layer lists into one
      # (depth, width, stride, rate, init_factor) tuple per conv layer.
      conv_args = list(zip(*self._conv_config))
      net = contrib_layers.stack(
          inputs_3d,
          conv1d,
          conv_args,
          scope='conv',
          padding='SAME',
          activation_fn=nonlinearity,
          w_initializer=w_initializer)
      # Flatten the conv output so it can be joined with the 2D features.
      net = contrib_layers.flatten(net)
      if inputs_2d is not None:
        net = tf.concat([net, inputs_2d], 1)
    else:
      net = inputs_2d

    if net.get_shape()[-1].value == 0:
      raise ValueError('feature dimension has size 0')

    # Config stores dropout rates; the layers API wants keep probabilities.
    keep_probs = [1 - d for d in self.config.dropouts]
    fc_args = list(zip(self.fc_sizes, keep_probs, self.fc_init_factors))

    net = contrib_layers.stack(
        net,
        dropout_and_fully_connected,
        fc_args[:-1],
        scope='fc',
        is_training=is_training,
        activation_fn=nonlinearity,
        w_initializer=w_initializer)

    # the last layer should not have a non-linearity
    net = dropout_and_fully_connected(
        net, *fc_args[-1], scope='fc_final', is_training=is_training,
        activation_fn=None, w_initializer=w_initializer)

    logits = lt.LabeledTensor(net, [batch_axis, self.logit_axis])
    return logits
232
233
234# must match the order of conv1d's arguments
235_ConvConfig = collections.namedtuple(
236'_ConvConfig', 'depths, widths, strides, rates, init_factors')
237
238
def conv1d(inputs,
           filter_depth,
           filter_width,
           stride=1,
           rate=1,
           init_factor=1.0,
           w_initializer=None,
           **kwargs):
  """Adds a convolutional 1d layer.

  If rate is 1 then a standard convolutional layer will be added,
  if rate is > 1 then an dilated (atrous) convolutional layer will
  be added.

  Args:
    inputs: a 3-D tensor `[batch_size, in_width, in_channels]`.
    filter_depth: integer, the number of output channels.
    filter_width: integer, size of the convolution kernel.
    stride: integer, size of the convolution stride.
    rate: integer, the size of the convolution dilation.
    init_factor: passed to `w_initializer`.
    w_initializer: function to call to create a weights initializer.
    **kwargs: passed on to `layers.conv2d`.

  Returns:
    A tensor variable representing the result of the series of operations.
  Raises:
    Error if rate > 1 and stride != 1. Current implementation of
    atrous_conv2d does not allow a stride other than 1.
  """
  with tf.name_scope('conv1d'):
    # conv2d over a height-1 "image" is equivalent to a 1d convolution, so
    # insert a singleton axis: ['batch', 'position', 'channel'] becomes
    # ['batch', 1, 'position', 'channel']. Convolution happens over the
    # middle two dimensions.
    expanded = tf.expand_dims(inputs, 1)
    result = contrib_layers.conv2d(
        expanded,
        filter_depth,
        [1, filter_width],
        [1, stride],
        rate=[1, rate],
        weights_initializer=w_initializer(factor=init_factor),
        **kwargs)
    # Drop the singleton height axis to return to the 1d layout.
    return tf.squeeze(result, [1])
290
291
def dropout_and_fully_connected(inputs,
                                num_outputs,
                                keep_prob=0.5,
                                init_factor=1.0,
                                is_training=True,
                                w_initializer=None,
                                **kwargs):
  """Apply dropout followed by a fully connected layer.

  Args:
    inputs: A tensor of with at least rank 2 and value for the last dimension,
      i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
    num_outputs: Integer or long, the number of output units in the layer.
    keep_prob: A scalar `Tensor` with the same type as x. The probability
      that each element is kept.
    init_factor: passed to `w_initializer`.
    is_training: A bool `Tensor` indicating whether or not the model
      is in training mode. If so, dropout is applied and values scaled.
      Otherwise, dropout is skipped.
    w_initializer: Function to call to create a weights initializer.
    **kwargs: passed on to `layers.fully_connected`.

  Returns:
    A tensor variable representing the result of the series of operations.
  """
  # Dropout first (a no-op when is_training is False), then the dense layer.
  dropped = contrib_layers.dropout(
      inputs, keep_prob=keep_prob, is_training=is_training)
  init = w_initializer(factor=init_factor)
  return contrib_layers.fully_connected(
      dropped, num_outputs, weights_initializer=init, **kwargs)
323