google-research
46 строк · 1.6 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16"""Tests for load_model."""
17
18from absl.testing import absltest19from absl.testing import parameterized20import numpy as np21import tensorflow as tf22from loss_functions_transfer import load_model23
24
class LoadModelTest(parameterized.TestCase):
  """Builds, restores and runs the model once for each entry in LOSS_HYPERPARAMETERS."""

  @parameterized.parameters(list(load_model.LOSS_HYPERPARAMETERS))
  def test_build_and_restore(self, loss_name):
    """Builds the graph for `loss_name`, restores a checkpoint and runs it.

    The loss is evaluated once on an all-zeros input batch and must be finite.

    Args:
      loss_name: One element of `list(load_model.LOSS_HYPERPARAMETERS)`
        (presumably a key, if that object is a dict -- confirm).
    """
    # Passed to restore_checkpoint; presumably selects which seed's
    # checkpoint is loaded -- confirm against load_model.
    seed = 0
    # Spatial size of the (None, image_size, image_size, 3) input
    # placeholder; assumes NHWC layout -- confirm.
    image_size = 224
    # Width of the label placeholder.
    label_dim = 1001

    inputs_np = np.zeros((1, image_size, image_size, 3))
    # A single one-hot label row of width `label_dim` with the 1 in the last
    # column (left-pads [[1]] with label_dim - 1 zeros).
    labels_np = np.pad([[1]], ((0, 0), (label_dim - 1, 0)))

    with tf.Graph().as_default():
      inputs = tf.compat.v1.placeholder(
          tf.float32, (None, image_size, image_size, 3))
      labels = tf.compat.v1.placeholder(tf.float32, (None, label_dim))
      loss, endpoints = load_model.build_model_and_compute_loss(
          loss_name=loss_name, inputs=inputs, labels=labels,
          is_training=False)
      with tf.compat.v1.Session() as sess:
        load_model.restore_checkpoint(loss_name, seed, sess)
        # Endpoints are fetched too so that every endpoint tensor is
        # evaluated, not just the loss; only the loss value is inspected.
        loss_value, _ = sess.run(
            (loss, endpoints),
            feed_dict={inputs: inputs_np, labels: labels_np})

    # Check the result rather than discarding it, so a NaN/Inf loss from a
    # broken graph or checkpoint fails the test instead of passing silently.
    self.assertTrue(
        np.all(np.isfinite(loss_value)),
        msg='Non-finite loss for %r: %s' % (loss_name, loss_value))
44
if __name__ == '__main__':
  # Executing this module directly hands control to absltest's test runner.
  absltest.main()