google-research
69 строк · 2.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Common layers used in the sparse transformer."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23from tensor2tensor.layers import common_layers
24import tensorflow.compat.v1 as tf
25
26from state_of_sparsity.sparse_transformer.layers import common_sparse
27
28
def dense_relu_dense(inputs,
                     filter_size,
                     output_size,
                     output_activation=None,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     sparsity_technique=None,
                     threshold=3.0,
                     clip_alpha=None,
                     training=True,
                     name=None,
                     initial_sparsity=None):
  """Applies a ReLU hidden layer followed by a projection, optionally sparse.

  Both layers are built with the same layer constructor. The first maps
  `inputs` to `filter_size` units with a ReLU activation. The second maps the
  (optionally dropped-out) hidden activations to `output_size` units.

  Args:
    inputs: Input tensor fed to the first dense layer.
    filter_size: Number of units in the hidden ReLU layer.
    output_size: Number of units in the output projection.
    output_activation: Activation for the output layer. With the default
      `None` the projection is linear.
    dropout: Dropout rate for the hidden activations. Dropout is skipped
      entirely when this equals 0.0. Otherwise `1.0 - dropout` is passed on as
      the keep probability. NOTE(review): this is applied independently of
      `training`, so callers presumably set it to 0.0 outside of training.
    dropout_broadcast_dims: Forwarded as `broadcast_dims` to
      `common_layers.dropout_with_broadcast_dims`.
    sparsity_technique: When truthy, both layers are built with
      `common_sparse.dense` using this technique. Otherwise the plain
      `common_layers.dense` is used, and `threshold`, `clip_alpha`,
      `training` and `initial_sparsity` are ignored.
    threshold: Forwarded to `common_sparse.dense` (sparse path only).
    clip_alpha: Forwarded to `common_sparse.dense` (sparse path only).
    training: Forwarded to `common_sparse.dense` (sparse path only).
    name: Optional prefix for the layer names. When set, the layers are named
      "<name>_conv1" and "<name>_conv2"; otherwise "conv1" and "conv2".
    initial_sparsity: Forwarded to `common_sparse.dense` (sparse path only).

  Returns:
    The output of the second (projection) layer.
  """
  if sparsity_technique:
    # Bind the sparse-only options once so both layers share them.
    make_dense = functools.partial(
        common_sparse.dense,
        sparsity_technique=sparsity_technique,
        threshold=threshold,
        training=training,
        clip_alpha=clip_alpha,
        initial_sparsity=initial_sparsity)
  else:
    make_dense = common_layers.dense

  # Both layers are dense layers; the "conv" suffixes are presumably kept for
  # variable-name compatibility.
  if name:
    name_template = "%s_{}" % name
  else:
    name_template = "{}"

  hidden = make_dense(
      inputs,
      filter_size,
      activation=tf.nn.relu,
      use_bias=True,
      name=name_template.format("conv1"))

  if dropout != 0.0:
    keep_prob = 1.0 - dropout
    hidden = common_layers.dropout_with_broadcast_dims(
        hidden, keep_prob, broadcast_dims=dropout_broadcast_dims)

  return make_dense(
      hidden,
      output_size,
      use_bias=True,
      activation=output_activation,
      name=name_template.format("conv2"))
70