google-research

Форк
0
321 строка · 11.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# pylint: skip-file
17
"""Utility code for generating and saving image grids and checkpointing.
18

19
   The save_image code is copied from
20
   https://github.com/google/flax/blob/master/examples/vae/utils.py,
21
   which is a JAX equivalent to the same function in TorchVision
22
   (https://github.com/pytorch/vision/blob/master/torchvision/utils.py)
23
"""
24

25
import collections
26
import math
27
import re
28
from typing import Any, Dict, Optional, TypeVar
29

30
from absl import logging
31
import flax
32
import jax
33
import jax.numpy as jnp
34
from PIL import Image
35
import tensorflow as tf
36

37
T = TypeVar("T")
38

39

40
def load_state_dict(filepath, state):
41
  with tf.io.gfile.GFile(filepath, "rb") as f:
42
    state = flax.serialization.from_bytes(state, f.read())
43
  return state
44

45

46
class CheckpointInfo(
47
    collections.namedtuple("CheckpointInfo", ("prefix", "number"))):
48
  """Helper class to parse a TensorFlow checkpoint path."""
49

50
  CHECKPOINT_REGEX = r"^(?P<prefix>.*)-(?P<number>\d+)"
51

52
  @classmethod
53
  def initialize(cls, base_directory, checkpoint_name):
54
    """Creates a first CheckpointInfo (number=1)."""
55
    return cls(f"{base_directory}/{checkpoint_name}", 1)
56

57
  @classmethod
58
  def from_path(cls, checkpoint):
59
    """Parses a checkpoint.
60

61
    Args:
62
      checkpoint: A checkpoint prefix, as can be found in the
63
        `.latest_checkpoint` property of a `tf.train.CheckpointManager`.
64

65
    Returns:
66
      An instance of `CheckpointInfo` that represents `checkpoint`.
67
    """
68
    m = re.match(cls.CHECKPOINT_REGEX, checkpoint)
69
    if m is None:
70
      RuntimeError(f"Invalid checkpoint format: {checkpoint}")
71
    d = m.groupdict()  # pytype: disable=attribute-error
72
    return cls(d["prefix"], int(d["number"]))
73

74
  def increment(self):
75
    """Returns a new CheckpointInfo with `number` increased by one."""
76
    return CheckpointInfo(self.prefix, self.number + 1)
77

78
  def __str__(self):
79
    """Does the opposite of `.from_path()`."""
80
    return f"{self.prefix}-{self.number}"
81

82

83
class Checkpoint:
84
  """A utility class for storing and loading TF2/Flax checkpoints.
85

86

87
  Both the state of a `tf.data.Dataset` iterator and a `flax.struct.dataclass`
88
  are stored on disk in the following files:
89

90
  - {directory}/checkpoint
91
  - {directory}/ckpt-{number}.index
92
  - {directory}/ckpt-{number}.data@*
93
  - {directory}/ckpt-{number}.flax
94

95
  Where {number} starts at 1 is then incremented by 1 for every new checkpoint.
96
  The last file is the `flax.struct.dataclass`, serialized in Messagepack
97
  format. The other files are explained in more detail in the Tensorflow
98
  documentation:
99

100
  https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint
101
  """
102

103
  def __init__(self,
104
               base_directory,
105
               tf_state = None,
106
               *,
107
               max_to_keep = None,
108
               checkpoint_name = "ckpt"):
109
    """Initializes a Checkpoint with a dictionary of TensorFlow Trackables.
110

111
    Args:
112
      base_directory: Directory under which the checkpoints will be stored. Use
113
        a different base_directory in every task.
114
      tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
115
        example a dataset iterator.
116
      max_to_keep: Number of checkpoints to keep in the directory. If there are
117
        more checkpoints than specified by this number, then the oldest
118
        checkpoints are removed.
119
      checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
120
    """
121
    if tf_state is None:
122
      tf_state = dict()
123
    self.base_directory = base_directory
124
    self.max_to_keep = max_to_keep
125
    self.checkpoint_name = checkpoint_name
126
    self.tf_checkpoint = tf.train.Checkpoint(**tf_state)
127
    self.tf_checkpoint_manager = tf.train.CheckpointManager(
128
        self.tf_checkpoint,
129
        base_directory,
130
        max_to_keep=max_to_keep,
131
        checkpoint_name=checkpoint_name)
132

133
  def get_latest_checkpoint_to_restore_from(self):
134
    """Returns the latest checkpoint to restore from.
135

136
    In the current implementation, this method simply returns the attribute
137
    `latest_checkpoint`.
138

139
    Subclasses can override this method to provide an alternative checkpoint to
140
    restore from, for example for synchronization across multiple checkpoint
141
    directories.
142
    """
143
    return self.latest_checkpoint
144

145
  @property
146
  def latest_checkpoint(self):
147
    """Latest checkpoint, see `tf.train.CheckpointManager.latest_checkpoint`.
148

149
    Returns:
150
      A string to the latest checkpoint. Note that this string is path-like but
151
      it does not really describe a file, but rather a set of files that are
152
      constructed from this string, by appending different file extensions. The
153
      returned value is `None` if there is no previously stored checkpoint in
154
      `base_directory` specified to `__init__()`.
155
    """
156
    return self.tf_checkpoint_manager.latest_checkpoint
157

158
  @property
159
  def latest_checkpoint_flax(self):
160
    """Path of the latest serialized `state`.
161

162
    Returns:
163
      Path of the file containing the serialized Flax state. The returned value
164
      is `None` if there is no previously stored checkpoint in `base_directory`
165
      specified to `__init__()`.
166
    """
167
    if self.latest_checkpoint is None:
168
      return None
169
    return self._flax_path(self.latest_checkpoint)
170

171
  def _flax_path(self, checkpoint):
172
    return "{}.flax".format(checkpoint)
173

174
  def _next_checkpoint(self, checkpoint):
175
    if checkpoint is None:
176
      return str(
177
          CheckpointInfo.initialize(self.base_directory, self.checkpoint_name))
178
    return str(CheckpointInfo.from_path(checkpoint).increment())
179

180
  def save(self, state):
181
    """Saves a new checkpoints in the directory.
182

183
    Args:
184
      state: Flax checkpoint to be stored.
185

186
    Returns:
187
      The checkpoint identifier ({base_directory}/ckpt-{number}).
188
    """
189
    next_checkpoint = self._next_checkpoint(self.latest_checkpoint)
190
    flax_path = self._flax_path(next_checkpoint)
191
    if not tf.io.gfile.exists(self.base_directory):
192
      tf.io.gfile.makedirs(self.base_directory)
193
    with tf.io.gfile.GFile(flax_path, "wb") as f:
194
      f.write(flax.serialization.to_bytes(state))
195
    checkpoints = set(self.tf_checkpoint_manager.checkpoints)
196
    # Write Tensorflow data last. This way Tensorflow checkpoint generation
197
    # logic will make sure to only commit checkpoints if they complete
198
    # successfully. A previously written `flax_path` would then simply be
199
    # overwritten next time.
200
    self.tf_checkpoint_manager.save()
201
    for checkpoint in checkpoints.difference(
202
        self.tf_checkpoint_manager.checkpoints):
203
      tf.io.gfile.remove(self._flax_path(checkpoint))
204
    if next_checkpoint != self.latest_checkpoint:
205
      raise AssertionError(  # pylint: disable=g-doc-exception
206
          "Expected next_checkpoint to match latest_checkpoint: "
207
          f"{next_checkpoint} != {self.latest_checkpoint}")
208
    return self.latest_checkpoint  # pytype: disable=bad-return-type  # always-use-return-annotations
209

210
  def restore_or_initialize(self, state):
211
    """Restores from the latest checkpoint, or creates a first checkpoint.
212

213
    Args:
214
      state : A flax checkpoint to be stored or to serve as a template. If the
215
        checkoint is restored (and not initialized), then the fields of `state`
216
        must match the data previously stored.
217

218
    Returns:
219
      The restored `state` object. Note that all TensorFlow `Trackable`s in
220
      `tf_state` (see `__init__()`) are also updated.
221
    """
222
    latest_checkpoint = self.get_latest_checkpoint_to_restore_from()
223
    if not latest_checkpoint:
224
      logging.info("No previous checkpoint found.")
225
      # Only save one copy for host 0.
226
      if jax.host_id() == 0:
227
        self.save(state)
228
      return state
229
    self.tf_checkpoint.restore(latest_checkpoint)
230
    flax_path = self._flax_path(latest_checkpoint)
231
    with tf.io.gfile.GFile(flax_path, "rb") as f:
232
      state = flax.serialization.from_bytes(state, f.read())
233
    return state
234

235
  def restore(self, state):
236
    """Restores from the latest checkpoint.
237

238
    Similar to `restore_or_initialize()`, but raises a `FileNotFoundError` if
239
    there is no checkpoint.
240

241
    Args:
242
      state : A flax checkpoint to be stored or to serve as a template. If the
243
        checkoint is restored (and not initialized), then the fields of `state`
244
        must match the data previously stored.
245

246
    Returns:
247
      The restored `state` object. Note that all TensorFlow `Trackable`s in
248
      `tf_state` (see `__init__()`) are also updated.
249

250
    Raises:
251
      FileNotFoundError: If there is no checkpoint to restore.
252
    """
253
    latest_checkpoint = self.get_latest_checkpoint_to_restore_from()
254
    if not latest_checkpoint:
255
      raise FileNotFoundError(f"No checkpoint found at {self.base_directory}")
256
    return self.restore_or_initialize(state)
257

258

259
def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format=None):
260
  """Make a grid of images and Save it into an image file.
261

262
  Args:
263
    ndarray (array_like): 4D mini-batch images of shape (B x H x W x C).
264
    fp: A filename(string) or file object.
265
    nrow (int, optional): Number of images displayed in each row of the grid.
266
      The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
267
    padding (int, optional): amount of padding. Default: ``2``.
268
    pad_value (float, optional): Value for the padded pixels. Default: ``0``.
269
    format(Optional):  If omitted, the format to use is determined from the
270
      filename extension. If a file object was used instead of a filename, this
271
      parameter should always be used.
272
  """
273
  if not (isinstance(ndarray, jnp.ndarray) or
274
          (isinstance(ndarray, list) and
275
           all(isinstance(t, jnp.ndarray) for t in ndarray))):
276
    raise TypeError("array_like of tensors expected, got {}".format(
277
        type(ndarray)))
278

279
  ndarray = jnp.asarray(ndarray)
280

281
  if ndarray.ndim == 4 and ndarray.shape[-1] == 1:  # single-channel images
282
    ndarray = jnp.concatenate((ndarray, ndarray, ndarray), -1)
283

284
  # make the mini-batch of images into a grid
285
  nmaps = ndarray.shape[0]
286
  xmaps = min(nrow, nmaps)
287
  ymaps = int(math.ceil(float(nmaps) / xmaps))
288
  height, width = int(ndarray.shape[1] + padding), int(ndarray.shape[2] +
289
                                                       padding)
290
  num_channels = ndarray.shape[3]
291
  grid = jnp.full(
292
      (height * ymaps + padding, width * xmaps + padding, num_channels),
293
      pad_value).astype(jnp.float32)
294
  k = 0
295
  for y in range(ymaps):
296
    for x in range(xmaps):
297
      if k >= nmaps:
298
        break
299
      grid = grid.at[y * height + padding:(y + 1) * height,
300
                     x * width + padding:(x + 1) * width].set(ndarray[k])
301
      k = k + 1
302

303
  # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
304
  ndarr = jnp.clip(grid * 255.0 + 0.5, 0, 255).astype(jnp.uint8)
305
  im = Image.fromarray(ndarr.copy())
306
  im.save(fp, format=format)
307

308

309
def flatten_dict(config):
310
  """Flatten a hierarchical dict to a simple dict."""
311
  new_dict = {}
312
  for key, value in config.items():
313
    if isinstance(value, dict):
314
      sub_dict = flatten_dict(value)
315
      for subkey, subvalue in sub_dict.items():
316
        new_dict[key + "/" + subkey] = subvalue
317
    elif isinstance(value, tuple):
318
      new_dict[key] = str(value)
319
    else:
320
      new_dict[key] = value
321
  return new_dict
322

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.