google-research
119 строк · 3.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Input function utils."""
17
18import tensorflow as tf19
20
21def _decode_record(record, name_to_features):22"""Decodes a record to a TensorFlow example."""23example = tf.io.parse_single_example(record, name_to_features)24
25# tf.Example only supports tf.int64, but the TPU only supports tf.int32.26# So cast all int64 to int32.27for name in list(example.keys()):28t = example[name]29if t.dtype == tf.int64:30t = tf.cast(t, tf.int32)31example[name] = t32
33return example34
35
36def create_redace_dataset(37input_file,38seq_length,39batch_size=512,40is_training=True,41use_weighted_labels=True,42):43"""Creates input dataset from tfrecords files for RED-ACE model.44
45Args:
46input_file: Input file path.
47seq_length: Maximum length of sequence.
48batch_size: Size of batch.
49is_training: Whether dataset is used for training.
50use_weighted_labels: Whether different labels were given different weights.
51Primarly used to increase the importance of rare tags.
52
53Returns:
54tensorflow dataset.
55"""
56tagging_name_to_features = {57'input_ids':58tf.io.FixedLenFeature([seq_length], tf.int64),59'input_mask':60tf.io.FixedLenFeature([seq_length], tf.int64),61'segment_ids':62tf.io.FixedLenFeature([seq_length], tf.int64),63'bucketed_confidence_scores':64tf.io.FixedLenFeature([seq_length], tf.int64),65'labels':66tf.io.FixedLenFeature([seq_length], tf.int64),67}68
69tagging_name_to_features['labels_mask'] = tf.io.FixedLenFeature([seq_length],70tf.float32)71
72name_to_features = tagging_name_to_features73
74d = tf.data.Dataset.from_tensor_slices(tf.constant([input_file]))75dataset = d.interleave(76tf.data.TFRecordDataset,77cycle_length=1,78num_parallel_calls=tf.data.experimental.AUTOTUNE,79).repeat()80
81if is_training:82dataset = dataset.shuffle(buffer_size=min(1, 100))83options = tf.data.Options()84options.experimental_distribute.auto_shard_policy = (85tf.data.experimental.AutoShardPolicy.DATA)86options.experimental_deterministic = not is_training87dataset = dataset.with_options(options)88decode_fn = lambda record: _decode_record(record, name_to_features)89dataset = dataset.map(90decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)91
92def _select_data_from_record(record):93"""Filter out features to use for pretraining."""94x = {95'input_word_ids': record['input_ids'],96'input_mask': record['input_mask'],97'input_type_ids': record['segment_ids'],98'input_confidence_scores': record['bucketed_confidence_scores'],99'edit_tags': record['labels'],100}101if use_weighted_labels and 'labels_mask' in record:102x['labels_mask'] = record['labels_mask']103else:104x['labels_mask'] = record['input_mask']105
106y = record['input_ids']107
108return (x, y)109
110dataset = dataset.map(111_select_data_from_record,112num_parallel_calls=tf.data.experimental.AUTOTUNE)113
114if is_training:115dataset = dataset.shuffle(50000)116
117dataset = dataset.batch(batch_size, drop_remainder=True)118dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)119return dataset120