google-research
222 строки · 6.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Dataset preprocessing and pipeline.
17
18Built for Trembl dataset.
19"""
20import os21import types22from absl import logging23import gin24import numpy as np25import tensorflow.compat.v1 as tf26
27
28from protein_lm import domains29
30
@gin.configurable
def make_protein_domain(include_anomalous_amino_acids=True,
                        include_bos=True,
                        include_eos=True,
                        include_pad=True,
                        include_mask=True,
                        length=1024):
  """Builds a variable-length protein domain with the requested vocabulary.

  Args:
    include_anomalous_amino_acids: Include anomalous amino acids in the vocab.
    include_bos: Include a beginning-of-sequence token.
    include_eos: Include an end-of-sequence token.
    include_pad: Include a padding token.
    include_mask: Include a mask token.
    length: Maximum sequence length for the domain.

  Returns:
    A `domains.VariableLengthDiscreteDomain` over a `domains.ProteinVocab`.
  """
  vocab = domains.ProteinVocab(
      include_anomalous_amino_acids=include_anomalous_amino_acids,
      include_bos=include_bos,
      include_eos=include_eos,
      include_pad=include_pad,
      include_mask=include_mask)
  return domains.VariableLengthDiscreteDomain(vocab=vocab, length=length)
48
# Module-level default domain (vocab + max length) shared by the helpers
# below for encoding, EOS lookup, and pad-value lookup.
protein_domain = make_protein_domain()
51
def dataset_from_tensors(tensors):
  """Converts nested tf.Tensors or np.ndarrays to a tf.data.Dataset.

  Args:
    tensors: Tensors/arrays, or a generator or list of them. Generators and
      lists are materialized into a tuple first so `from_tensor_slices`
      treats them as a nested structure rather than a single stacked tensor.

  Returns:
    A `tf.data.Dataset` slicing the inputs along their first dimension.
  """
  # Idiom fix: isinstance accepts a tuple of types; one call replaces the
  # original `isinstance(...) or isinstance(...)` chain.
  if isinstance(tensors, (types.GeneratorType, list)):
    tensors = tuple(tensors)
  return tf.data.Dataset.from_tensor_slices(tensors)
58
def _parse_example(value):
  """Decodes a serialized tf.Example into a dense int64 sequence tensor."""
  features = {'sequence': tf.io.VarLenFeature(tf.int64)}
  parsed = tf.parse_single_example(value, features=features)
  # VarLenFeature yields a SparseTensor; densify before returning.
  return tf.sparse.to_dense(parsed['sequence'])
65
@gin.configurable
def get_train_valid_files(directory, num_test_files=10, num_valid_files=1):
  """Lists files in `directory` and splits them into train/validation sets.

  After sorting, the first `num_test_files` files are reserved for testing
  and excluded from both returned lists; the next `num_valid_files` files
  form the validation set, and the remainder the training set.

  Args:
    directory: Directory containing data.
    num_test_files: Number of leading files to set aside for testing.
    num_valid_files: Number of files to use for validation.

  Returns:
    Tuple of (train files, validation files).
  """
  entries = tf.gfile.ListDirectory(directory)
  # Skip temp files; filter on the bare name, then join to full paths.
  paths = sorted(
      os.path.join(directory, name) for name in entries if 'tmp' not in name)
  valid_start = num_test_files
  train_start = num_test_files + num_valid_files
  valid_files = paths[valid_start:train_start]
  train_files = paths[train_start:]
  return train_files, valid_files
86
def _add_eos(seq):
  """Appends the end-of-sequence token id along the last axis of `seq`."""
  # TODO(ddohan): Support non-protein domains.
  eos_token = protein_domain.vocab.eos
  return tf.concat([seq, [eos_token]], axis=-1)
92
def load_dataset(train_files,
                 test_files,
                 shuffle_buffer=8192,
                 batch_size=32,
                 max_train_length=512,
                 max_eval_length=None):
  """Builds train/test tf.data pipelines from TFRecord shards.

  Args:
    train_files: Files to load training data from.
    test_files: Files to load test data from.
    shuffle_buffer: Shuffle buffer size for training.
    batch_size: Batch size.
    max_train_length: Length to crop train sequences to.
    max_eval_length: Length to crop eval sequences to. Defaults to
      `max_train_length` when None.

  Returns:
    Tuple of (train dataset, test dataset).
  """
  max_eval_length = max_eval_length or max_train_length
  logging.info('Training on %s shards', len(train_files))
  print('Training on %s shards' % len(train_files))
  print('Test on %s shards' % str(test_files))

  test_ds = tf.data.TFRecordDataset(test_files)

  # Read training data from many files in parallel, in shuffled file order.
  filenames_dataset = tf.data.Dataset.from_tensor_slices(train_files).shuffle(
      2048)
  train_ds = filenames_dataset.interleave(
      tf.data.TFRecordDataset,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
      deterministic=False)

  train_ds = train_ds.map(
      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  test_ds = test_ds.map(
      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  train_ds = batch_ds(
      train_ds,
      batch_size=batch_size,
      shuffle_buffer=shuffle_buffer,
      length=max_train_length)
  test_ds = batch_ds(
      test_ds,
      batch_size=batch_size,
      shuffle_buffer=None,
      length=max_eval_length)

  # Bug fix: Dataset.prefetch is functional and returns a new dataset; the
  # original discarded the return value, so prefetching never took effect.
  train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
  test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
  return train_ds, test_ds
149
@gin.configurable
def batch_ds(ds,
             length=512,
             batch_size=32,
             shuffle_buffer=8192,
             pack_length=None):
  """Crops, shuffles, and batches a dataset of sequences.

  Args:
    ds: Dataset of 1-D integer sequences.
    length: If set, crop each sequence to at most this length.
    batch_size: Batch size.
    shuffle_buffer: If set, shuffle with this buffer size.
    pack_length: If set, pack cropped sequences (EOS-terminated) into
      fixed-length rows of this many tokens instead of padding.

  Returns:
    The batched dataset.
  """
  autotune = tf.data.experimental.AUTOTUNE

  if length:
    ds = ds.map(lambda seq: seq[:length], num_parallel_calls=autotune)
  if shuffle_buffer:
    ds = ds.shuffle(buffer_size=shuffle_buffer, reshuffle_each_iteration=True)

  if not pack_length:
    # Pad every sequence in a batch up to `length` with the pad token.
    pad_value = np.array(protein_domain.vocab.pad, dtype=np.int64)
    return ds.padded_batch(
        batch_size,
        padded_shapes=length,
        padding_values=pad_value,
        drop_remainder=True)

  logging.info('Packing sequences to length %s', pack_length)
  # Append EOS so sequence boundaries survive packing.
  ds = ds.map(_add_eos, num_parallel_calls=autotune)
  # Pack by concatenation: flatten to a token stream, regroup into
  # pack_length-sized rows, then add the batch dimension.
  ds = ds.unbatch()
  ds = ds.batch(pack_length)
  ds = ds.batch(batch_size, drop_remainder=True)
  return ds
183
def _encode_protein(protein_string):
  """Encodes a protein string into a numpy array of vocab token ids."""
  encoded = protein_domain.encode([protein_string], pad=False)
  return np.array(encoded)
189
def _sequence_to_tf_example(sequence):
  """Wraps an integer sequence in a tf.train.Example under key 'sequence'."""
  # Flatten so arbitrary-rank inputs serialize as one int64 list.
  flat = np.array(sequence).reshape(-1)
  sequence_feature = tf.train.Feature(
      int64_list=tf.train.Int64List(value=flat))
  return tf.train.Example(
      features=tf.train.Features(feature={'sequence': sequence_feature}))
199
def _write_tfrecord(sequences, outdir, idx, total):
  """Write iterable of sequences to sstable shard idx/total in outdir."""
  # Zero-pad shard indices to 5 digits: e.g. data-00003-of-00010.
  shard_name = 'data-%0.5d-of-%0.5d' % (idx, total)
  path = os.path.join(outdir, shard_name)
  with tf.io.TFRecordWriter(path) as writer:
    for seq in sequences:
      writer.write(_sequence_to_tf_example(seq).SerializeToString())
211
def csv_to_tfrecord(csv_path, outdir, idx, total):
  """Process csv at `csv_path` to shard idx/total in outdir."""
  with tf.gfile.GFile(csv_path) as f:

    def encoded_rows():
      # Each row is "<id>,<sequence>"; only the sequence column is encoded.
      for line in f:
        _, seq = line.strip().split(',')
        yield _encode_protein(seq)

    # Stream rows straight into the shard writer; the file stays open
    # until the generator is fully consumed.
    _write_tfrecord(encoded_rows(), outdir, idx, total)