google-research

Форк
0
/
routing_transformer_test.py 
87 строк · 3.4 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for SM3 optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.utils import trainer_lib
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

from routing_transformer import sparse_transformer as sptf


class RoutingTransformerTest(tf.test.TestCase):
  """Tests for the routing (sparse) transformer model."""

  def testSparseTransformer(self):
    """Tests that cached decoding ignores target tokens past decode_step.

    Runs the model body twice on target sequences that share the first
    `decode_step - i` tokens and differ only afterwards. Since decoding at
    `decode_step` must not look ahead, both runs should produce identical
    outputs.
    """
    with self.cached_session() as sess:
      with tf.variable_scope("sparse_transformer", reuse=tf.AUTO_REUSE):
        hparams_set = "sparse_transformer_local"
        problem = ""
        hparams = trainer_lib.create_hparams(hparams_set, problem_name=problem)
        # Disable all stochastic layers so the two forward passes are
        # directly comparable.
        hparams.layer_prepostprocess_dropout = 0.
        hparams.dropout = 0.
        hparams.num_encoder_layers = 0
        hparams.num_decoder_layers = 2
        hparams.local_relative = False
        hparams.query_shape = (20,)
        hparams.memory_flange = (0,)
        hparams.max_length = 200
        sparse_transformer = sptf.SparseTransformer(hparams)
        sparse_transformer.set_mode(tf_estimator.ModeKeys.PREDICT)
        sparse_transformer.vocab_size = 50
        features = {}
        decode_step = 10
        cache = {}
        # Testing that changing target tokens beyond decode_step has no effect
        # i = 0 or less should have the next cell sum == 0
        i = -5
        # Stateless RNG keeps the shared prefix identical across both runs.
        targets_prefix = tf.random.stateless_uniform(
            [1, decode_step - i],
            minval=0,
            maxval=sparse_transformer.vocab_size,
            dtype=tf.dtypes.int32,
            seed=(75, 48))
        zeros = tf.zeros([1, hparams.max_length - decode_step + i],
                         dtype=tf.int32)
        features["targets"] = tf.concat([targets_prefix, zeros],
                                        axis=-1)
        output_step1 = sparse_transformer.body(features,
                                               decode_step=decode_step,
                                               cache=cache)
        # Same prefix, but the suffix beyond decode_step is now random
        # instead of zeros.
        features["targets"] = tf.concat([
            targets_prefix, tf.random.stateless_uniform(
                [1, hparams.max_length - decode_step + i],
                minval=0,
                maxval=sparse_transformer.vocab_size,
                dtype=tf.dtypes.int32,
                seed=(67, 89))], axis=-1)
        output_step2 = sparse_transformer.body(features,
                                               decode_step=decode_step,
                                               cache=cache)
        # tf.global_variables_initializer() always returns an op, so the
        # previous `is not None` guard was dead code.
        sess.run(tf.global_variables_initializer())

        output1_np = sess.run(output_step1)
        output2_np = sess.run(output_step2)
        self.assertEqual(output1_np.shape, output2_np.shape)
        # A shape check alone cannot detect look-ahead: the shapes are
        # equal by construction. Compare the actual values, which is
        # what this test is meant to verify.
        self.assertAllClose(output1_np, output2_np)
84

85

86
# Run all test cases in this module when executed as a script.
if __name__ == "__main__":
  tf.test.main()
88

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.