google-research

Форк
0
70 строк · 2.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Common weight initializers used in the sparse transformer."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import math
22

23
import tensorflow.compat.v1 as tf
24

25
from tensorflow.python.framework import dtypes  # pylint: disable=g-direct-tensorflow-import
26
from tensorflow.python.ops import init_ops  # pylint: disable=g-direct-tensorflow-import
27

28

29
class SparseGlorotUniform(init_ops.Initializer):
  """Re-weighted glorot uniform initializer based on sparsity.

  Dense Glorot (Xavier) uniform initialization samples from U(-limit, limit)
  with limit = sqrt(6 / (fan_in + fan_out)). When only a fraction
  `1 - sparsity` of the weights are non-zero, this initializer replaces the
  dense fans with the expected number of non-zero weights per column and per
  row:

    limit = sqrt(6 / ((1 - sparsity) * (fan_in + fan_out)))

  The initializer itself produces a dense tensor; it does not zero any
  entries. Only 2-D weight matrices are supported.

  Args:
    sparsity: Fraction of weights expected to be zero, in [0.0, 1.0).
    seed: Optional seed forwarded to `tf.random_uniform`.
    dtype: Default floating-point dtype of the generated values.

  Raises:
    ValueError: If `sparsity` is outside [0.0, 1.0) (including NaN), or if
      `dtype` is not a floating-point type.
  """

  def __init__(self, sparsity, seed=None, dtype=tf.float32):
    # Negated chained comparison so that NaN is rejected as well: every
    # comparison against NaN is False, which the old `< 0 or >= 1` let pass.
    if not 0.0 <= sparsity < 1.0:
      raise ValueError("sparsity must be in range [0.0, 1.0).")

    self.sparsity = sparsity
    self.seed = seed
    self.dtype = init_ops._assert_float_dtype(  # pylint: disable=protected-access
        dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Samples a tensor of `shape` from the sparsity-rescaled uniform range.

    Args:
      shape: 2-D shape of the weight matrix, `[fan_in, fan_out]`.
      dtype: Optional floating-point dtype; defaults to the dtype given at
        construction.
      partition_info: Must be None; partitioned variables are not supported.

    Returns:
      A dense tensor of the requested shape drawn from U(-limit, limit).

    Raises:
      ValueError: If `partition_info` is given, `shape` is not 2-D, or
        `dtype` is not a floating-point type.
    """
    if partition_info is not None:
      raise ValueError("partition_info not supported.")
    if dtype is None:
      dtype = self.dtype
    else:
      # Validate an explicitly passed dtype the same way the constructor does.
      dtype = init_ops._assert_float_dtype(  # pylint: disable=protected-access
          dtypes.as_dtype(dtype))

    if len(shape) != 2:
      raise ValueError("Weights must be 2-dimensional.")

    # For a 2-D shape these are shape[0] and shape[1].
    fan_in, fan_out = init_ops._compute_fans(shape)  # pylint: disable=protected-access

    # With nnz = fan_in * fan_out * (1 - sparsity) non-zero weights, the
    # expected non-zeros per column (effective fan-in) is nnz / fan_out and
    # per row (effective fan-out) is nnz / fan_in. Their sum simplifies to
    # (1 - sparsity) * (fan_in + fan_out), which avoids dividing by a fan
    # that may be zero.
    denominator = (1.0 - self.sparsity) * (fan_in + fan_out)
    # An empty matrix gets no samples, so any limit works; skip the division.
    limit = math.sqrt(6.0 / denominator) if denominator > 0 else 0.0

    return tf.random_uniform(
        shape,
        -limit,
        limit,
        dtype,
        seed=self.seed)

  def get_config(self):
    """Returns the constructor arguments, so `from_config` can rebuild it."""
    # `sparsity` is a required constructor argument; without it here,
    # `from_config` (which calls `cls(**config)`) would raise TypeError.
    return {
        "sparsity": self.sparsity,
        "seed": self.seed,
        "dtype": self.dtype.name
    }
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.