google-research / redace_input_pipeline.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input function utils."""

import tensorflow as tf

def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.io.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.cast(t, tf.int32)
    example[name] = t

  return example


def create_redace_dataset(
    input_file,
    seq_length,
    batch_size=512,
    is_training=True,
    use_weighted_labels=True,
):
  """Creates an input dataset from TFRecord files for the RED-ACE model.

  Args:
    input_file: Input file path.
    seq_length: Maximum sequence length.
    batch_size: Batch size.
    is_training: Whether the dataset is used for training.
    use_weighted_labels: Whether different labels are given different weights.
      Primarily used to increase the importance of rare tags.

  Returns:
    A tf.data.Dataset.
  """
  name_to_features = {
      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'bucketed_confidence_scores': tf.io.FixedLenFeature([seq_length],
                                                          tf.int64),
      'labels': tf.io.FixedLenFeature([seq_length], tf.int64),
      'labels_mask': tf.io.FixedLenFeature([seq_length], tf.float32),
  }

  d = tf.data.Dataset.from_tensor_slices(tf.constant([input_file]))
  dataset = d.interleave(
      tf.data.TFRecordDataset,
      cycle_length=1,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  ).repeat()

  if is_training:
    # Note: with a single input file, min(1, 100) = 1, so this file-level
    # shuffle is effectively a no-op; example-level shuffling happens below.
    dataset = dataset.shuffle(buffer_size=min(1, 100))
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.DATA)
  options.experimental_deterministic = not is_training
  dataset = dataset.with_options(options)
  decode_fn = lambda record: _decode_record(record, name_to_features)
  dataset = dataset.map(
      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _select_data_from_record(record):
    """Selects and renames the features the RED-ACE model consumes."""
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids'],
        'input_confidence_scores': record['bucketed_confidence_scores'],
        'edit_tags': record['labels'],
    }
    if use_weighted_labels and 'labels_mask' in record:
      x['labels_mask'] = record['labels_mask']
    else:
      # Fall back to the input mask, weighting all labels equally.
      x['labels_mask'] = record['input_mask']

    y = record['input_ids']

    return (x, y)

  dataset = dataset.map(
      _select_data_from_record,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  if is_training:
    dataset = dataset.shuffle(50000)

  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset
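

# The sketch below is not part of the original file. It is a hedged example of
# how a TFRecord matching the feature spec above might be produced; the helper
# name and the all-zero placeholder values are assumptions for illustration.
def _write_demo_tfrecord(path, seq_length, num_examples=4):
  """Writes trivial all-zero examples in the schema expected above."""
  int_feature = lambda values: tf.train.Feature(
      int64_list=tf.train.Int64List(value=values))
  float_feature = lambda values: tf.train.Feature(
      float_list=tf.train.FloatList(value=values))
  with tf.io.TFRecordWriter(path) as writer:
    for _ in range(num_examples):
      features = tf.train.Features(
          feature={
              'input_ids': int_feature([0] * seq_length),
              'input_mask': int_feature([0] * seq_length),
              'segment_ids': int_feature([0] * seq_length),
              'bucketed_confidence_scores': int_feature([0] * seq_length),
              'labels': int_feature([0] * seq_length),
              'labels_mask': float_feature([0.0] * seq_length),
          })
      writer.write(tf.train.Example(features=features).SerializeToString())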
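

# Minimal usage sketch, also not part of the original file: builds a tiny demo
# TFRecord with the hypothetical helper above and reads one batch back. The
# path and the seq_length/batch_size values are illustrative, not defaults.
if __name__ == '__main__':
  demo_path = '/tmp/redace_demo.tfrecord'  # hypothetical path
  _write_demo_tfrecord(demo_path, seq_length=8, num_examples=4)
  demo_dataset = create_redace_dataset(
      demo_path, seq_length=8, batch_size=2, is_training=False)
  # Each element is an (inputs, targets) pair of fixed-shape batches.
  demo_features, demo_targets = next(iter(demo_dataset))
  print(demo_features['input_word_ids'].shape)  # (2, 8)
  print(demo_targets.shape)  # (2, 8)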