google-research

Форк
0
/
datasets_test.py 
62 строки · 2.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for task_set.datasets."""
17

18
import numpy as np
19

20
from task_set import datasets
21
import tensorflow.compat.v1 as tf
22

23

24
class DatasetsTest(tf.test.TestCase):
25

26
  def test_split_dataset(self):
27
    slices = tf.constant(np.arange(100))
28
    dataset = tf.data.Dataset.from_tensor_slices(slices)
29
    d1, d2, d3 = datasets.split_dataset(dataset, num_splits=3, num_per_split=10)
30

31
    d1_example = d1.make_one_shot_iterator().get_next()
32
    d2_example = d2.make_one_shot_iterator().get_next()
33
    d3_example = d3.make_one_shot_iterator().get_next()
34

35
    with self.test_session() as sess:
36
      d1s = [sess.run(d1_example) for _ in range(10)]
37
      d2s = [sess.run(d2_example) for _ in range(10)]
38
      d3s = [sess.run(d3_example) for _ in range(80)]
39
      np.testing.assert_equal(d1s, np.arange(10))
40
      np.testing.assert_equal(d2s, np.arange(10) + 10)
41
      np.testing.assert_equal(d3s, np.arange(80) + 20)
42

43
    with self.assertRaises(tf.errors.OutOfRangeError):
44
      sess.run(d1_example)
45

46
    with self.assertRaises(tf.errors.OutOfRangeError):
47
      sess.run(d2_example)
48

49
    with self.assertRaises(tf.errors.OutOfRangeError):
50
      sess.run(d3_example)
51

52
  def test_food101_32x32(self):
53
    """Sanity check that food101 loads data."""
54
    dataset = datasets.tfds_load_dataset("food101_32x32", split="train")
55
    batch = dataset.make_one_shot_iterator().get_next()
56
    with self.cached_session() as sess:
57
      np_batch = sess.run(batch)
58
    self.assertEqual(tuple(np_batch["image"].shape), (32, 32, 3))
59

60

61
if __name__ == "__main__":
62
  tf.test.main()
63

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.