google-research
62 строки · 2.1 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for task_set.datasets."""
17
18import numpy as np19
20from task_set import datasets21import tensorflow.compat.v1 as tf22
23
class DatasetsTest(tf.test.TestCase):
  """Tests for the dataset helpers in `task_set.datasets`."""

  def test_split_dataset(self):
    """Checks that `split_dataset` carves a dataset into ordered splits.

    A dataset of the 100 consecutive integers 0..99 is split with
    `num_splits=3` and `num_per_split=10`. The assertions below expect the
    first two splits to yield 10 elements each (0..9 and 10..19), the last
    split to yield the remaining 80 elements (20..99), and every split to be
    exhausted afterwards.
    """
    slices = tf.constant(np.arange(100))
    dataset = tf.data.Dataset.from_tensor_slices(slices)
    d1, d2, d3 = datasets.split_dataset(dataset, num_splits=3, num_per_split=10)

    # `Dataset.make_one_shot_iterator()` is deprecated; the free function is
    # the documented replacement and also accepts V2 datasets.
    d1_example = tf.data.make_one_shot_iterator(d1).get_next()
    d2_example = tf.data.make_one_shot_iterator(d2).get_next()
    d3_example = tf.data.make_one_shot_iterator(d3).get_next()

    # `test_session` is deprecated on `tf.test.TestCase`; `cached_session` is
    # its replacement and matches the helper used in `test_food101_32x32`.
    with self.cached_session() as sess:
      d1s = [sess.run(d1_example) for _ in range(10)]
      d2s = [sess.run(d2_example) for _ in range(10)]
      d3s = [sess.run(d3_example) for _ in range(80)]
      np.testing.assert_equal(d1s, np.arange(10))
      np.testing.assert_equal(d2s, np.arange(10) + 10)
      np.testing.assert_equal(d3s, np.arange(80) + 20)

      # Each split must be exhausted after yielding exactly its elements.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(d1_example)

      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(d2_example)

      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(d3_example)

  def test_food101_32x32(self):
    """Sanity check that food101 loads data."""
    dataset = datasets.tfds_load_dataset("food101_32x32", split="train")
    # Same non-deprecated iterator construction as in `test_split_dataset`.
    batch = tf.data.make_one_shot_iterator(dataset).get_next()
    with self.cached_session() as sess:
      np_batch = sess.run(batch)
      # Only the image shape of the first element is checked.
      self.assertEqual(tuple(np_batch["image"].shape), (32, 32, 3))
60
if __name__ == "__main__":
  # When executed as a script, hand control to TensorFlow's test runner.
  tf.test.main()