google-research

Форк
0
64 строки · 2.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for weight initializers."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import tensorflow.compat.v1 as tf
22

23
from state_of_sparsity.sparse_transformer.layers import common_init
24

25

26
class SparseGlorotUniformTest(tf.test.TestCase):
27

28
  def testSparseGlorotUniform_OutputShape(self):
29
    initializer = common_init.SparseGlorotUniform(.5)
30
    x = tf.get_variable(
31
        "x",
32
        shape=[512, 1024],
33
        initializer=initializer,
34
        dtype=tf.float32)
35
    with self.test_session() as sess:
36
      sess.run(tf.global_variables_initializer())
37
      res = sess.run(x)
38
    self.assertEqual(res.shape, (512, 1024))
39

40
  def testSparseGlorotUniform_NoSparsity(self):
41
    initializer = common_init.SparseGlorotUniform(0, seed=5)
42
    initializer_base = tf.glorot_uniform_initializer(seed=5)
43

44
    x = tf.get_variable(
45
        "x",
46
        shape=[512, 1024],
47
        initializer=initializer,
48
        dtype=tf.float32)
49
    y = tf.get_variable(
50
        "y",
51
        shape=[512, 1024],
52
        initializer=initializer_base,
53
        dtype=tf.float32)
54

55
    with self.test_session() as sess:
56
      sess.run(tf.global_variables_initializer())
57
      res_x = sess.run(x)
58
      res_y = sess.run(y)
59
    self.assertEqual(res_x.shape, (512, 1024))
60
    self.assertEqual(res_y.shape, (512, 1024))
61
    self.assertAllEqual(res_x, res_y)
62

63
if __name__ == "__main__":
64
  tf.test.main()
65

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.