# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=logging-format-interpolation
# pylint: disable=g-long-lambda
# pylint: disable=logging-not-lazy
# pylint: disable=protected-access

r"""Input pipelines for CIFAR-10 training, DDS, and evaluation."""

import os

from absl import logging

import numpy as np

import tensorflow.compat.v1 as tf  # tf

from differentiable_data_selection import augment

CIFAR_PATH = ''
CIFAR_MEAN = np.array([0.491400, 0.482158, 0.4465309], np.float32) * 255.
CIFAR_STDDEV = np.array([0.247032, 0.243485, 0.26159], np.float32) * 255.


################################################################################
#                                                                              #
# OUTSIDE INTERFACE                                                            #
#                                                                              #
################################################################################


def _dataset_service(params, dataset, start_index=None, final_index=None):
  """Wraps `dataset` into `dataset_service` (currently a no-op passthrough)."""
  return dataset


def get_image_mean_and_std(params):
  """Returns per-channel pixel mean and stddev for `params.dataset_name`."""
  if params.dataset_name.lower().startswith('cifar'):
    return CIFAR_MEAN, CIFAR_STDDEV
  else:
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')


def convert_and_normalize(params, images):
  """Subtract mean and divide stddev depending on the dataset."""
  dtype = tf.bfloat16 if params.use_bfloat16 else tf.float32
  if 'cifar' in params.dataset_name.lower():
    images = tf.cast(images, dtype)
  else:
    images = tf.image.convert_image_dtype(images, dtype)
  shape = [1, 1, 1, 3] if len(images.shape.as_list()) == 4 else [1, 1, 3]

  mean, stddev = get_image_mean_and_std(params)
  mean = tf.reshape(tf.cast(mean, images.dtype), shape)
  stddev = tf.reshape(tf.cast(stddev, images.dtype), shape)
  images = (images - mean) / stddev

  return images


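# Illustrative sketch (not part of the original file): typical use of
# `convert_and_normalize`. Assumes a `params` object with
# `dataset_name='cifar10'` and `use_bfloat16=False`; CIFAR pixels stay in
# [0, 255] after the cast, so the output is roughly zero-mean, unit-variance
# per channel.
def _example_normalize_uint8_batch(params):
  """Example only: normalizes a dummy uint8 CIFAR batch."""
  images = tf.zeros([8, 32, 32, 3], dtype=tf.uint8)  # NHWC uint8 pixels
  return convert_and_normalize(params, images)  # float, (x - mean) / stddev

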
def get_eval_size(params):
  """Returns the padded eval set size for `params.dataset_name`."""
  eval_sizes = {
      'cifar10': 10000,
      'cifar10_dds': 10000,
  }
  if params.dataset_name.lower() not in eval_sizes.keys():
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')
  eval_size = eval_sizes[params.dataset_name.lower()]
  return compute_num_padded_data(params, eval_size)


def build_eval_dataset(params,
                       batch_size=None,
                       num_workers=None,
                       worker_index=None):
  """Builds `eval_data` depending on `params.dataset_name`."""
  if params.dataset_name.lower() in ['cifar10', 'cifar10_dds']:
    eval_data = cifar10_eval(
        params, batch_size=batch_size, eval_mode='test',
        num_workers=num_workers, worker_index=worker_index)
  else:
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')

  return eval_data


def build_train_infeeds(params):
  """Create the TPU infeed ops."""
  dev_assign = params.device_assignment
  host_to_tpus = {}
  for replica_id in range(params.num_replicas):
    host_device = dev_assign.host_device(replica=replica_id, logical_core=0)
    tpu_ordinal = dev_assign.tpu_ordinal(replica=replica_id, logical_core=0)
    logging.info(f'replica_id={replica_id} '
                 f'host_device={host_device} '
                 f'tpu_ordinal={tpu_ordinal}')

    if host_device not in host_to_tpus:
      host_to_tpus[host_device] = [tpu_ordinal]
    else:
      assert tpu_ordinal not in host_to_tpus[host_device]
      host_to_tpus[host_device].append(tpu_ordinal)

  infeed_ops = []
  infeed_graphs = []
  num_inputs = len(host_to_tpus)
  for i, (host, tpus) in enumerate(host_to_tpus.items()):
    infeed_graph = tf.Graph()
    infeed_graphs.append(infeed_graph)
    with infeed_graph.as_default():
      def enqueue_fn(host_device=host, input_index=i, device_ordinals=tpus):
        """Builds this host's input pipeline and enqueues one batch per core."""
        worker_infeed_ops = []
        with tf.device(host_device):
          dataset = build_train_dataset(
              params=params,
              batch_size=params.train_batch_size // num_inputs,
              num_inputs=num_inputs,
              input_index=input_index)
          inputs = tf.data.make_one_shot_iterator(dataset).get_next()

          num_splits = len(device_ordinals)
          if len(device_ordinals) > 1:
            inputs = [tf.split(v, num_splits, 0) for v in inputs]
          else:
            inputs = [[v] for v in inputs]
          input_dtypes = [v[0].dtype for v in inputs]
          input_shapes = [v[0].shape for v in inputs]
          params.add_hparam('train_dtypes', input_dtypes)
          params.add_hparam('train_shapes', input_shapes)
          for j, device_ordinal in enumerate(device_ordinals):
            worker_infeed_ops.append(tf.raw_ops.InfeedEnqueueTuple(
                inputs=[v[j] for v in inputs],
                shapes=input_shapes,
                device_ordinal=device_ordinal))
        return worker_infeed_ops
      def _body(i):
        with tf.control_dependencies(enqueue_fn()):
          return i+1
      infeed_op = tf.while_loop(
          lambda step: tf.less(step, tf.cast(params.save_every, step.dtype)),
          _body, [0], parallel_iterations=1, name='train_infeed').op
      infeed_ops.append(infeed_op)

  return infeed_ops, infeed_graphs


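# Illustrative sketch (not part of the original file): one common way to drive
# the per-host infeed graphs built above is to run each enqueue loop in its own
# session and thread. The `master` target and session setup here are
# assumptions for illustration, not part of this module.
def _example_run_train_infeeds(params, master=''):
  """Example only: runs each train infeed graph in a background thread."""
  import threading  # local import keeps the sketch self-contained

  infeed_ops, infeed_graphs = build_train_infeeds(params)
  threads = []
  for infeed_op, infeed_graph in zip(infeed_ops, infeed_graphs):
    def _run(op=infeed_op, graph=infeed_graph):
      with tf.Session(master, graph=graph) as sess:
        sess.run(op)
    thread = threading.Thread(target=_run)
    thread.start()
    threads.append(thread)
  return threads

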
def build_eval_infeeds(params):
  """Create the TPU infeed ops."""

  eval_size = get_eval_size(params)
  num_eval_steps = eval_size // params.eval_batch_size

  dev_assign = params.device_assignment
  host_to_tpus = {}
  for replica_id in range(params.num_replicas):
    host_device = dev_assign.host_device(replica=replica_id, logical_core=0)
    tpu_ordinal = dev_assign.tpu_ordinal(replica=replica_id, logical_core=0)

    if host_device not in host_to_tpus:
      host_to_tpus[host_device] = [tpu_ordinal]
    else:
      assert tpu_ordinal not in host_to_tpus[host_device]
      host_to_tpus[host_device].append(tpu_ordinal)

  infeed_ops = []
  infeed_graphs = []
  num_inputs = len(host_to_tpus)
  for i, (host, tpus) in enumerate(host_to_tpus.items()):
    infeed_graph = tf.Graph()
    infeed_graphs.append(infeed_graph)
    with infeed_graph.as_default():
      def enqueue_fn(host_device=host, input_index=i, device_ordinals=tpus):
        """Builds this host's eval pipeline and enqueues one batch per core."""
        worker_infeed_ops = []
        with tf.device(host_device):
          dataset = build_eval_dataset(
              params,
              batch_size=params.eval_batch_size // num_inputs,
              num_workers=num_inputs,
              worker_index=input_index)
          inputs = tf.data.make_one_shot_iterator(dataset).get_next()

          num_splits = len(device_ordinals)
          if len(device_ordinals) > 1:
            inputs = [tf.split(v, num_splits, 0) for v in inputs]
          else:
            inputs = [[v] for v in inputs]
          input_dtypes = [v[0].dtype for v in inputs]
          input_shapes = [v[0].shape for v in inputs]
          params.add_hparam('eval_dtypes', input_dtypes)
          params.add_hparam('eval_shapes', input_shapes)
          for j, device_ordinal in enumerate(device_ordinals):
            worker_infeed_ops.append(tf.raw_ops.InfeedEnqueueTuple(
                inputs=[v[j] for v in inputs],
                shapes=input_shapes,
                device_ordinal=device_ordinal))
        return worker_infeed_ops
      def _body(i):
        with tf.control_dependencies(enqueue_fn()):
          return i+1
      infeed_op = tf.while_loop(
          lambda step: tf.less(step, tf.cast(num_eval_steps, step.dtype)),
          _body, [0], parallel_iterations=1, name='eval_infeed').op
      infeed_ops.append(infeed_op)

  return infeed_ops, infeed_graphs, eval_size


def build_train_dataset(params, batch_size=None, num_inputs=1, input_index=0):
  """Builds the training dataset depending on `params.dataset_name`."""
  del num_inputs
  del input_index
  if params.dataset_name.lower() == 'cifar10':
    dataset = cifar10_train(params, batch_size)
    dataset = _dataset_service(params, dataset)
  elif params.dataset_name.lower() == 'cifar10_dds':
    dataset = cifar10_dds(params, batch_size)
    dataset = _dataset_service(params, dataset)
  else:
    raise ValueError(f'Unknown dataset_name `{params.dataset_name}`')

  return dataset


def _smallest_mul(a, m):
  """Rounds `a` up to the smallest multiple of `m` that is >= `a`."""
  return (a + m-1) // m * m


def compute_num_padded_data(params, data_size):
  """Pads `data_size` to a multiple of `eval_batch_size`, plus one extra batch."""
  return (_smallest_mul(data_size, params.eval_batch_size)
          + params.eval_batch_size)


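# Worked example (illustrative only): with `eval_batch_size=1024` and the
# 10000 CIFAR-10 test images, `_smallest_mul(10000, 1024)` is 10240, so the
# padded eval size is 10240 + 1024 = 11264. The extra records are dummy
# examples whose sample weight is 0 (see `_add_sample_weight` below).
def _example_padded_eval_size():
  """Example only: checks the padding arithmetic above."""
  assert _smallest_mul(10000, 1024) == 10240
  assert 10240 + 1024 == 11264

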
def _add_sample_weight(params, dataset, num_images):
  """Appends a sample weight: 1.0 for real examples, 0.0 for dummy padding."""

  if 'eval_image_size' in params:
    image_size = max(params.image_size, params.eval_image_size)
  else:
    image_size = params.image_size

  dtype = tf.bfloat16 if params.use_bfloat16 else tf.float32
  dummy_dataset = tf.data.Dataset.from_tensors((
      tf.zeros([image_size, image_size, 3], dtype=dtype),
      tf.zeros([params.num_classes], dtype=tf.float32),
      tf.constant(0., dtype=tf.float32),
  )).repeat()

  def _transform(images, labels):
    return images, labels, tf.constant(1., dtype=tf.float32)
  dataset = dataset.map(_transform, tf.data.experimental.AUTOTUNE)
  dataset = dataset.concatenate(dummy_dataset)
  dataset = dataset.take(num_images // params.num_workers).cache().repeat()

  return dataset


def _optimize_dataset(dataset):
  """Routines to optimize `Dataset`'s speed."""
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  options = tf.data.Options()
  options.experimental_optimization.parallel_batch = True
  options.experimental_optimization.map_fusion = True
  options.experimental_optimization.map_parallelization = True
  dataset = dataset.with_options(options)
  dataset = dataset.prefetch(1)
  return dataset


def _flip_and_jitter(x, replace_value=0):
  """Flip left/right and jitter."""
  x = tf.image.random_flip_left_right(x)
  image_size = min([x.shape[0], x.shape[1]])
  pad_size = image_size // 8
  x = tf.pad(x, paddings=[[pad_size, pad_size], [pad_size, pad_size], [0, 0]],
             constant_values=replace_value)
  x = tf.image.random_crop(x, [image_size, image_size, 3])
  x.set_shape([image_size, image_size, 3])
  return x


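# Illustrative sketch (not part of the original file): the jitter above pads a
# 32x32 image by 4 pixels on each side and takes a random 32x32 crop, the
# standard CIFAR-10 shift augmentation; 128 is used as the pad value so the
# padding is mid-gray before normalization.
def _example_jitter_cifar_image():
  """Example only: applies flip-and-jitter to a dummy uint8 CIFAR image."""
  image = tf.zeros([32, 32, 3], dtype=tf.uint8)
  return _flip_and_jitter(image, replace_value=128)  # still [32, 32, 3]

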
def _jitter(x, replace_value=0):
  """Jitter only (pad and random-crop, without flipping)."""
  image_size = min([x.shape[0], x.shape[1]])
  pad_size = image_size // 8
  x = tf.pad(x, paddings=[[pad_size, pad_size], [pad_size, pad_size], [0, 0]],
             constant_values=replace_value)
  x = tf.image.random_crop(x, [image_size, image_size, 3])
  x.set_shape([image_size, image_size, 3])
  return x


################################################################################
#                                                                              #
# CIFAR-10                                                                     #
#                                                                              #
################################################################################


def _cifar10_parser(params, value, training):
  """Parses a raw CIFAR-10 record into an (image, one-hot label) pair."""
  image_size = params.image_size
  # Each record is 1 label byte followed by a 3x32x32 (CHW) uint8 image.
  value = tf.io.decode_raw(value, tf.uint8)
  label = tf.cast(value[0], tf.int32)
  label = tf.one_hot(label, depth=params.num_classes, dtype=tf.float32)
  image = tf.reshape(value[1:], [3, 32, 32])  # uint8
  image = tf.transpose(image, [1, 2, 0])
  if image_size != 32:
    image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
  image.set_shape([image_size, image_size, 3])

  if training:
    if params.use_augment:
      aug = augment.RandAugment(cutout_const=image_size//8,
                                translate_const=image_size//8,
                                magnitude=params.augment_magnitude)
      image = _flip_and_jitter(image, 128)
      image = aug.distort(image)
      image = augment.cutout(image, pad_size=image_size//4, replace=128)
    else:
      image = _flip_and_jitter(image, 128)
  image = convert_and_normalize(params, image)
  return image, label


def cifar10_train(params, batch_size=None):
  """Load CIFAR-10 training data."""
  shuffle_size = batch_size * 16

  filenames = [os.path.join(CIFAR_PATH, 'train.bin')]
  record_bytes = 1 + (3 * 32 * 32)
  dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)
  # The first 5000 records are held out as the validation set used by
  # `cifar10_dds`.
  dataset = dataset.skip(5000).cache()
  dataset = dataset.map(
      lambda x: _cifar10_parser(params, x, training=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.shuffle(shuffle_size).repeat()
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
  dataset = _optimize_dataset(dataset)

  return dataset


def cifar10_dds(params, batch_size=None):
  """Load CIFAR-10 data for differentiable data selection (DDS)."""
  shuffle_size = batch_size * 16

  filenames = [os.path.join(CIFAR_PATH, 'train.bin')]
  record_bytes = 1 + (3 * 32 * 32)
  dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)

  # Training stream: everything after the first 5000 records.
  train_dataset = dataset.skip(5000).cache()
  train_dataset = train_dataset.map(
      lambda x: _cifar10_parser(params, x, training=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  train_dataset = train_dataset.shuffle(shuffle_size).repeat()
  train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
  train_dataset = train_dataset.batch(batch_size=batch_size,
                                      drop_remainder=True)

  # Validation stream: the first 5000 records, parsed without augmentation.
  valid_dataset = dataset.take(5000).cache()
  valid_dataset = valid_dataset.map(
      lambda x: _cifar10_parser(params, x, training=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  valid_dataset = valid_dataset.shuffle(shuffle_size).repeat()
  valid_dataset = valid_dataset.prefetch(tf.data.experimental.AUTOTUNE)
  valid_dataset = valid_dataset.batch(batch_size=batch_size,
                                      drop_remainder=True)

  # Each element is (train_images, train_labels, valid_images, valid_labels).
  dataset = tf.data.Dataset.zip((train_dataset, valid_dataset))
  dataset = dataset.map(lambda a, b: tuple([a[0], a[1], b[0], b[1]]),
                        tf.data.experimental.AUTOTUNE)

  dataset = _optimize_dataset(dataset)
  return dataset


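# Illustrative sketch (not part of the original file): how the 4-tuple elements
# of the DDS dataset can be unpacked. Assumes a `params` object carrying the
# fields read by `_cifar10_parser` (`image_size`, `num_classes`, `use_augment`,
# `use_bfloat16`, `dataset_name`).
def _example_peek_dds_batch(params, batch_size=4):
  """Example only: builds the DDS dataset and unpacks one batch's tensors."""
  dataset = cifar10_dds(params, batch_size=batch_size)
  iterator = tf.data.make_one_shot_iterator(dataset)
  train_images, train_labels, valid_images, valid_labels = iterator.get_next()
  return train_images, train_labels, valid_images, valid_labels

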
def cifar10_eval(params, batch_size=None, eval_mode=None,
                 num_workers=None, worker_index=None):
  """Load CIFAR-10 eval data (`valid` or `test` split)."""

  if batch_size is None:
    batch_size = params.eval_batch_size

  if eval_mode == 'valid':
    filenames = [os.path.join(CIFAR_PATH, 'val.bin')]
  elif eval_mode == 'test':
    filenames = [os.path.join(CIFAR_PATH, 'test_batch.bin')]
  else:
    raise ValueError(f'Unknown eval_mode {eval_mode}')

  record_bytes = 1 + (3 * 32 * 32)
  dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)
  if num_workers is not None and worker_index is not None:
    dataset = dataset.shard(num_workers, worker_index)

  dataset = dataset.map(
      lambda x: _cifar10_parser(params, x, training=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Pad with zero-weight dummy examples so every worker sees the same number
  # of batches.
  dataset = _add_sample_weight(params, dataset,
                               num_images=get_eval_size(params))

  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
  dataset = _optimize_dataset(dataset)

  return dataset


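# Illustrative sketch (not part of the original file): because the eval data is
# padded with zero-weight dummy examples, metrics should be weighted by the
# third element of each batch. The tensor names below are assumptions for
# illustration.
def _example_weighted_correct_count(logits, labels, weights):
  """Example only: counts correct predictions, ignoring padded examples."""
  predictions = tf.argmax(logits, axis=-1)
  targets = tf.argmax(labels, axis=-1)
  correct = tf.cast(tf.equal(predictions, targets), tf.float32)
  return tf.reduce_sum(correct * weights)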
