# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset preprocessing and pipeline.

Built for the TrEMBL dataset.
"""
import os
import types
from absl import logging
import gin
import numpy as np
import tensorflow.compat.v1 as tf


from protein_lm import domains


@gin.configurable
def make_protein_domain(include_anomalous_amino_acids=True,
                        include_bos=True,
                        include_eos=True,
                        include_pad=True,
                        include_mask=True,
                        length=1024):
  return domains.VariableLengthDiscreteDomain(
      vocab=domains.ProteinVocab(
          include_anomalous_amino_acids=include_anomalous_amino_acids,
          include_bos=include_bos,
          include_eos=include_eos,
          include_pad=include_pad,
          include_mask=include_mask),
      length=length,
  )


protein_domain = make_protein_domain()
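
# Illustrative sketch (the sequence below is an arbitrary example): the
# module-level domain maps amino-acid strings to integer token ids and exposes
# the special tokens used by the pipeline below.
#   ids = protein_domain.encode(['MKTAYIAKQR'], pad=False)
#   eos_id, pad_id = protein_domain.vocab.eos, protein_domain.vocab.pad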


def dataset_from_tensors(tensors):
  """Converts nested tf.Tensors or np.ndarrays to a tf.data.Dataset."""
  if isinstance(tensors, types.GeneratorType) or isinstance(tensors, list):
    tensors = tuple(tensors)
  return tf.data.Dataset.from_tensor_slices(tensors)
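
# Minimal usage sketch (hypothetical arrays, not used by the pipeline):
#   dataset_from_tensors([np.zeros((4, 8)), np.ones(4)]) yields 4 elements,
#   each a (row, scalar) pair, via tf.data.Dataset.from_tensor_slices.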


def _parse_example(value):
  parsed = tf.parse_single_example(
      value, features={'sequence': tf.io.VarLenFeature(tf.int64)})
  sequence = tf.sparse.to_dense(parsed['sequence'])
  return sequence


@gin.configurable
def get_train_valid_files(directory, num_test_files=10, num_valid_files=1):
  """Given a directory, list files and split into train/validation files.

  Args:
    directory: Directory containing data.
    num_test_files: Number of files to set aside for testing.
    num_valid_files: Number of files to use for validation.

  Returns:
    Tuple of lists of (train files, validation files).
  """
  files = tf.gfile.ListDirectory(directory)
  files = [os.path.join(directory, f) for f in files if 'tmp' not in f]
  files = sorted(files)
  # Set aside the first num_test_files files for testing; the next
  # num_valid_files files are used for validation and the rest for training.
  valid_files = files[num_test_files:num_test_files + num_valid_files]
  train_files = files[num_test_files + num_valid_files:]
  return train_files, valid_files
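
# Example (hypothetical directory): with the defaults above, the first 10
# sorted shards are held out for testing, the next shard becomes the
# validation split, and the remainder are returned as training files:
#   train_files, valid_files = get_train_valid_files('/path/to/tfrecords')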


def _add_eos(seq):
  """Add end of sequence markers."""
  # TODO(ddohan): Support non-protein domains.
  return tf.concat([seq, [protein_domain.vocab.eos]], axis=-1)


def load_dataset(train_files,
                 test_files,
                 shuffle_buffer=8192,
                 batch_size=32,
                 max_train_length=512,
                 max_eval_length=None):
  """Load data from train and test TFRecord files.

  Args:
    train_files: Files to load training data from.
    test_files: Files to load test data from.
    shuffle_buffer: Shuffle buffer size for training.
    batch_size: Batch size.
    max_train_length: Length to crop train sequences to.
    max_eval_length: Length to crop eval sequences to.

  Returns:
    Tuple of (train dataset, test dataset)
  """
  max_eval_length = max_eval_length or max_train_length
  logging.info('Training on %s shards', len(train_files))
  print('Training on %s shards' % len(train_files))
  print('Test on %s shards' % str(test_files))

  test_ds = tf.data.TFRecordDataset(test_files)

  # Read training data from many files in parallel.
  filenames_dataset = tf.data.Dataset.from_tensor_slices(train_files).shuffle(
      2048)
  train_ds = filenames_dataset.interleave(
      tf.data.TFRecordDataset,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
      deterministic=False)

  train_ds = train_ds.map(
      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  test_ds = test_ds.map(
      _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  train_ds = batch_ds(
      train_ds,
      batch_size=batch_size,
      shuffle_buffer=shuffle_buffer,
      length=max_train_length)
  test_ds = batch_ds(
      test_ds,
      batch_size=batch_size,
      shuffle_buffer=None,
      length=max_eval_length)

  train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
  test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
  return train_ds, test_ds
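
# End-to-end usage sketch (hypothetical path; assumes TFRecord shards written
# by csv_to_tfrecord below already exist in the directory):
#   train_files, valid_files = get_train_valid_files('/path/to/tfrecords')
#   train_ds, valid_ds = load_dataset(train_files, valid_files, batch_size=32)
# With the default (non-packed) batching, each train element is an int64 batch
# of shape [batch_size, max_train_length].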


@gin.configurable
def batch_ds(ds,
             length=512,
             batch_size=32,
             shuffle_buffer=8192,
             pack_length=None):
  """Crop, shuffle, and batch a dataset of sequences."""

  def _crop(x):
    return x[:length]

  if length:
    ds = ds.map(_crop, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if shuffle_buffer:
    ds = ds.shuffle(buffer_size=shuffle_buffer, reshuffle_each_iteration=True)

  if pack_length:
    logging.info('Packing sequences to length %s', pack_length)
    # Add EOS tokens.
    ds = ds.map(_add_eos, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Pack sequences together by concatenating.
    ds = ds.unbatch()
    ds = ds.batch(pack_length)  # Regroup tokens into pack_length chunks.
    ds = ds.batch(batch_size, drop_remainder=True)  # Add batch dimension.
  else:
    ds = ds.padded_batch(
        batch_size,
        padded_shapes=length,
        padding_values=np.array(protein_domain.vocab.pad, dtype=np.int64),
        drop_remainder=True)
  return ds
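
# Note on the two modes above: with pack_length set, EOS-terminated sequences
# are concatenated and re-chunked, so every batch is densely packed with shape
# [batch_size, pack_length]; otherwise sequences are padded with the vocab pad
# token up to `length`, giving batches of shape [batch_size, length].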


def _encode_protein(protein_string):
  array = protein_domain.encode([protein_string], pad=False)
  array = np.array(array)
  return array


def _sequence_to_tf_example(sequence):
  sequence = np.array(sequence)
  features = {
      'sequence':
          tf.train.Feature(
              int64_list=tf.train.Int64List(value=sequence.reshape(-1))),
  }
  return tf.train.Example(features=tf.train.Features(feature=features))


def _write_tfrecord(sequences, outdir, idx, total):
  """Write iterable of sequences to TFRecord shard idx/total in outdir."""
  idx = '%0.5d' % idx
  total = '%0.5d' % total
  name = 'data-%s-of-%s' % (idx, total)
  path = os.path.join(outdir, name)
  with tf.io.TFRecordWriter(path) as writer:
    for seq in sequences:
      proto = _sequence_to_tf_example(seq)
      writer.write(proto.SerializeToString())


def csv_to_tfrecord(csv_path, outdir, idx, total):
  """Process csv at `csv_path` to shard idx/total in outdir."""
  with tf.gfile.GFile(csv_path) as f:

    def iterator():
      for line in f:
        _, seq = line.strip().split(',')
        yield _encode_protein(seq)

    it = iterator()
    _write_tfrecord(it, outdir, idx, total)
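

# Preprocessing sketch (hypothetical paths): each line of the input CSV is
# expected to have two comma-separated fields, with the raw amino-acid
# sequence second (see `iterator` above). Writing shard i of n shards:
#   csv_to_tfrecord('/data/proteins-%05d.csv' % i, '/data/tfrecords', i, n)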