google-research

Форк
0
73 строки · 2.2 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""Data pipeline."""
17

18
import ml_collections
19
import tensorflow as tf
20

21

22
def _split_tuning_shards(paths, num_test_shards=1):
  """Splits shard paths into disjoint (train, test) lists for hparam tuning.

  Paths are sorted first so the held-out shard(s) do not depend on the order
  in which the filesystem glob happens to return them.

  Args:
    paths: iterable of shard file paths.
    num_test_shards: number of shards (the first ones after sorting) held out
      as the tuning evaluation set.

  Returns:
    (train_paths, test_paths): disjoint lists whose union is `paths`.

  Raises:
    ValueError: if `num_test_shards` < 1, or if holding out `num_test_shards`
      would leave no shard for training.
  """
  ordered = sorted(paths)
  if num_test_shards < 1 or len(ordered) <= num_test_shards:
    raise ValueError(
        f'Tuning mode needs more than {num_test_shards} training shard(s) '
        f'(and at least one held out); got {len(ordered)}.'
    )
  return ordered[num_test_shards:], ordered[:num_test_shards]


def get_datasets(
    config: 'ml_collections.ConfigDict',
    data_config: 'ml_collections.ConfigDict',
    batch_size: int,
    repeat: bool = False,
):
  """Construct tf datasets given configs.

  Args:
    config: top level config. Only `config.tuning_mode` is read here.
    data_config: data specific config. Reads `train_example_path` and
      `test_example_path` (glob patterns for TFRecord shards) and
      `hidden_dims` (length of the `repr` feature vector).
    batch_size: batch size used for both the train and the test dataset.
    repeat: whether to repeat the training dataset indefinitely. The test
      dataset is never repeated.

  Returns:
    train_ds, test_ds: batched, prefetched datasets of dicts with keys
      'repr' (float32, shape [batch, hidden_dims]) and 'label' (int64,
      shape [batch]).

  Raises:
    ValueError: in tuning mode, if the training glob matches fewer than two
      shards.
  """
  train_example_paths = tf.io.gfile.glob(data_config.train_example_path)
  if config.tuning_mode:
    # Hparam tuning holds out one training shard as the evaluation set
    # (5% of the training data when there are 20 shards). The held-out shard
    # must be taken from the full list and removed from training; otherwise
    # the evaluation shard would also be trained on.
    train_example_paths, test_example_paths = _split_tuning_shards(
        train_example_paths
    )
  else:
    test_example_paths = tf.io.gfile.glob(data_config.test_example_path)

  def decode_fn(record_bytes):
    return tf.io.parse_single_example(
        # Data
        record_bytes,
        # Schema
        {
            'repr': tf.io.FixedLenFeature(
                [data_config.hidden_dims], dtype=tf.float32
            ),
            'label': tf.io.FixedLenFeature([], dtype=tf.int64),
        },
    )

  test_ds = tf.data.TFRecordDataset(test_example_paths)
  test_ds = test_ds.map(decode_fn)
  test_ds = test_ds.batch(batch_size)
  test_ds = test_ds.prefetch(10)

  train_ds = tf.data.TFRecordDataset(train_example_paths)
  train_ds = train_ds.map(decode_fn)
  if repeat:
    # Repeat before batching so batches span epoch boundaries.
    train_ds = train_ds.repeat()
  train_ds = train_ds.batch(batch_size)
  train_ds = train_ds.prefetch(10)
  return train_ds, test_ds
74

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.