google-research
70 строк · 2.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Common weight initializers used in the sparse transformer."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23import tensorflow.compat.v1 as tf
24
25from tensorflow.python.framework import dtypes # pylint: disable=g-direct-tensorflow-import
26from tensorflow.python.ops import init_ops # pylint: disable=g-direct-tensorflow-import
27
28
class SparseGlorotUniform(init_ops.Initializer):
  """Re-weighted glorot uniform initializer based on sparsity.

  Samples weights uniformly from `[-limit, limit]`, where `limit` is computed
  from the number of weights that remain non-zero at the given sparsity
  instead of from the dense element count.
  """

  def __init__(self, sparsity, seed=None, dtype=tf.float32):
    """Creates the initializer.

    Args:
      sparsity: Fraction of the weights that are zero. Must be in the range
        [0.0, 1.0).
      seed: Optional seed forwarded to `tf.random_uniform`.
      dtype: Default dtype of the generated values; must be a floating point
        type.

    Raises:
      ValueError: If `sparsity` is outside of [0.0, 1.0).
    """
    if sparsity < 0.0 or sparsity >= 1.0:
      raise ValueError("sparsity must be in range [0.0, 1.0).")

    self.sparsity = sparsity
    self.seed = seed
    self.dtype = init_ops._assert_float_dtype(  # pylint: disable=protected-access
        dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns a tensor of the given shape with sparsity-scaled uniform values.

    Args:
      shape: Shape of the weight matrix to initialize. Must have exactly two
        dimensions.
      dtype: Optional dtype of the result. Falls back to the dtype passed to
        the constructor when None.
      partition_info: Unsupported; must be None.

    Returns:
      The result of `tf.random_uniform` over `[-limit, limit]`.

    Raises:
      ValueError: If `partition_info` is not None or `shape` is not
        2-dimensional.
    """
    if partition_info is not None:
      raise ValueError("partition_info not supported.")
    if dtype is None:
      dtype = self.dtype

    if len(shape) != 2:
      raise ValueError("Weights must be 2-dimensional.")

    fan_in, fan_out = init_ops._compute_fans(shape)  # pylint: disable=protected-access

    # Calculate the number of non-zero weights in the weight matrix.
    nnz = 1.0
    for d in shape:
      nnz *= d
    nnz *= 1 - self.sparsity

    # Glorot-style limit in which each fan is replaced by its effective
    # (sparsity-scaled) value: nnz / fan_out and nnz / fan_in respectively.
    limit = math.sqrt(6.0 / (nnz / fan_out + nnz / fan_in))
    return tf.random_uniform(
        shape,
        -limit,
        limit,
        dtype,
        seed=self.seed)

  def get_config(self):
    """Returns the constructor arguments of this initializer as a dict.

    `sparsity` is included because it is a required constructor argument;
    without it the initializer cannot be rebuilt from the returned config
    by passing the config back to the constructor as keyword arguments.
    """
    return {
        "sparsity": self.sparsity,
        "seed": self.seed,
        "dtype": self.dtype.name
    }
71