google-research

Форк
0
69 строк · 2.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Common layers used in the sparse transformer."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import functools
22

23
from tensor2tensor.layers import common_layers
24
import tensorflow.compat.v1 as tf
25

26
from state_of_sparsity.sparse_transformer.layers import common_sparse
27

28

29
def dense_relu_dense(inputs,
                     filter_size,
                     output_size,
                     output_activation=None,
                     dropout=0.0,
                     dropout_broadcast_dims=None,
                     sparsity_technique=None,
                     threshold=3.0,
                     clip_alpha=None,
                     training=True,
                     name=None,
                     initial_sparsity=None):
  """Applies a ReLU dense layer followed by a second dense projection.

  Args:
    inputs: Value fed to the first dense layer (presumably a Tensor).
    filter_size: Size argument given to the first (hidden) dense layer.
    output_size: Size argument given to the second (output) dense layer.
    output_activation: Activation forwarded to the second dense layer.
    dropout: When not equal to 0.0, dropout is applied to the hidden
      activations, with `1.0 - dropout` passed to
      `common_layers.dropout_with_broadcast_dims`.
    dropout_broadcast_dims: Forwarded as `broadcast_dims` to the dropout call.
    sparsity_technique: If truthy, both layers are built with
      `common_sparse.dense` instead of `common_layers.dense`, and this value
      is forwarded to it.
    threshold: Forwarded to `common_sparse.dense`; unused when
      `sparsity_technique` is falsy.
    clip_alpha: Forwarded to `common_sparse.dense`; unused when
      `sparsity_technique` is falsy.
    training: Forwarded to `common_sparse.dense`; unused when
      `sparsity_technique` is falsy.
    name: Optional prefix for the layer names. The layers are named
      "<name>_conv1" and "<name>_conv2", or just "conv1" and "conv2" when no
      name is given.
    initial_sparsity: Forwarded to `common_sparse.dense`; unused when
      `sparsity_technique` is falsy.

  Returns:
    The result of the second dense layer.
  """
  # Pick the dense-layer constructor once; both layers share it.
  if sparsity_technique:
    dense_fn = functools.partial(
        common_sparse.dense,
        sparsity_technique=sparsity_technique,
        threshold=threshold,
        training=training,
        clip_alpha=clip_alpha,
        initial_sparsity=initial_sparsity)
  else:
    dense_fn = common_layers.dense

  # Template with one "{}" slot that receives the per-layer suffix.
  if name:
    name_template = "%s_{}" % name
  else:
    name_template = "{}"

  hidden = dense_fn(
      inputs,
      filter_size,
      use_bias=True,
      activation=tf.nn.relu,
      name=name_template.format("conv1"))

  if dropout != 0.0:
    keep_prob = 1.0 - dropout
    hidden = common_layers.dropout_with_broadcast_dims(
        hidden, keep_prob, broadcast_dims=dropout_broadcast_dims)

  return dense_fn(
      hidden,
      output_size,
      activation=output_activation,
      use_bias=True,
      name=name_template.format("conv2"))
70

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.