# google-research
# 87 lines · 3.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Routing Transformer sparse decoder."""
# NOTE(review): the original docstring said "Tests for SM3 optimizer", which
# contradicts the file contents (RoutingTransformerTest / sparse_transformer);
# corrected to describe what is actually tested.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.utils import trainer_lib
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

from routing_transformer import sparse_transformer as sptf

class RoutingTransformerTest(tf.test.TestCase):
  """Decode-step caching tests for the Routing Transformer model."""

  def testSparseTransformer(self):
    """Test sparse transformer decode.

    Builds a small local sparse transformer in PREDICT mode and runs
    `body()` twice at the same `decode_step` with targets that agree on the
    prefix up to `decode_step - i` but differ beyond it. Because decoding at
    step `decode_step` must not look at future target tokens, both runs
    should produce outputs of identical shape (the equality asserted here is
    shape-level only; values are not compared).
    """
    with self.cached_session() as sess:
      with tf.variable_scope("sparse_transformer", reuse=tf.AUTO_REUSE):
        # Minimal decoder-only configuration: no encoder, 2 decoder layers,
        # local attention with a query block of 20 and no memory flange.
        hparams_set = "sparse_transformer_local"
        problem = ""
        hparams = trainer_lib.create_hparams(hparams_set, problem_name=problem)
        hparams.layer_prepostprocess_dropout = 0.
        hparams.dropout = 0.
        hparams.num_encoder_layers = 0
        hparams.num_decoder_layers = 2
        hparams.local_relative = False
        hparams.query_shape = (20,)
        hparams.memory_flange = (0,)
        hparams.max_length = 200
        sparse_transformer = sptf.SparseTransformer(hparams)
        sparse_transformer.set_mode(tf_estimator.ModeKeys.PREDICT)
        sparse_transformer.vocab_size = 50
        features = {}
        decode_step = 10
        cache = {}
        # Testing that changing target tokens beyond decode_step has no
        # effect; i = 0 or less should have the next cell sum == 0.
        i = -5
        # Shared prefix of length decode_step - i; stateless seeds keep the
        # test deterministic across runs.
        targets_prefix = tf.random.stateless_uniform(
            [1, decode_step - i],
            minval=0,
            maxval=sparse_transformer.vocab_size,
            dtype=tf.dtypes.int32,
            seed=(75, 48))
        # First run: suffix beyond the prefix is all zeros.
        zeros = tf.zeros([1, hparams.max_length - decode_step + i],
                         dtype=tf.int32)
        features["targets"] = tf.concat([targets_prefix, zeros], axis=-1)
        output_step1 = sparse_transformer.body(features,
                                               decode_step=decode_step,
                                               cache=cache)
        # Second run: same prefix, but a random (different) suffix.
        features["targets"] = tf.concat([
            targets_prefix,
            tf.random.stateless_uniform(
                [1, hparams.max_length - decode_step + i],
                minval=0,
                maxval=sparse_transformer.vocab_size,
                dtype=tf.dtypes.int32,
                seed=(67, 89))
        ], axis=-1)
        output_step2 = sparse_transformer.body(features,
                                               decode_step=decode_step,
                                               cache=cache)
        initializer = tf.global_variables_initializer()
        if initializer is not None:
          initializer.run()

        output1_np = sess.run(output_step1)
        output2_np = sess.run(output_step2)
        self.assertEqual(output1_np.shape, output2_np.shape)

if __name__ == "__main__":
  tf.test.main()