# google-research
# 321 lines · 11.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Utility code for generating and saving image grids and checkpointing.

The save_image code is copied from
https://github.com/google/flax/blob/master/examples/vae/utils.py,
which is a JAX equivalent to the same function in TorchVision
(https://github.com/pytorch/vision/blob/master/torchvision/utils.py)
"""

import collections
import math
import re
from typing import Any, Dict, Optional, TypeVar

from absl import logging
import flax
import jax
import jax.numpy as jnp
from PIL import Image
import tensorflow as tf
37T = TypeVar("T")38
39
40def load_state_dict(filepath, state):41with tf.io.gfile.GFile(filepath, "rb") as f:42state = flax.serialization.from_bytes(state, f.read())43return state44
45
46class CheckpointInfo(47collections.namedtuple("CheckpointInfo", ("prefix", "number"))):48"""Helper class to parse a TensorFlow checkpoint path."""49
50CHECKPOINT_REGEX = r"^(?P<prefix>.*)-(?P<number>\d+)"51
52@classmethod53def initialize(cls, base_directory, checkpoint_name):54"""Creates a first CheckpointInfo (number=1)."""55return cls(f"{base_directory}/{checkpoint_name}", 1)56
57@classmethod58def from_path(cls, checkpoint):59"""Parses a checkpoint.60
61Args:
62checkpoint: A checkpoint prefix, as can be found in the
63`.latest_checkpoint` property of a `tf.train.CheckpointManager`.
64
65Returns:
66An instance of `CheckpointInfo` that represents `checkpoint`.
67"""
68m = re.match(cls.CHECKPOINT_REGEX, checkpoint)69if m is None:70RuntimeError(f"Invalid checkpoint format: {checkpoint}")71d = m.groupdict() # pytype: disable=attribute-error72return cls(d["prefix"], int(d["number"]))73
74def increment(self):75"""Returns a new CheckpointInfo with `number` increased by one."""76return CheckpointInfo(self.prefix, self.number + 1)77
78def __str__(self):79"""Does the opposite of `.from_path()`."""80return f"{self.prefix}-{self.number}"81
82
83class Checkpoint:84"""A utility class for storing and loading TF2/Flax checkpoints.85
86
87Both the state of a `tf.data.Dataset` iterator and a `flax.struct.dataclass`
88are stored on disk in the following files:
89
90- {directory}/checkpoint
91- {directory}/ckpt-{number}.index
92- {directory}/ckpt-{number}.data@*
93- {directory}/ckpt-{number}.flax
94
95Where {number} starts at 1 is then incremented by 1 for every new checkpoint.
96The last file is the `flax.struct.dataclass`, serialized in Messagepack
97format. The other files are explained in more detail in the Tensorflow
98documentation:
99
100https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint
101"""
102
103def __init__(self,104base_directory,105tf_state = None,106*,107max_to_keep = None,108checkpoint_name = "ckpt"):109"""Initializes a Checkpoint with a dictionary of TensorFlow Trackables.110
111Args:
112base_directory: Directory under which the checkpoints will be stored. Use
113a different base_directory in every task.
114tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
115example a dataset iterator.
116max_to_keep: Number of checkpoints to keep in the directory. If there are
117more checkpoints than specified by this number, then the oldest
118checkpoints are removed.
119checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
120"""
121if tf_state is None:122tf_state = dict()123self.base_directory = base_directory124self.max_to_keep = max_to_keep125self.checkpoint_name = checkpoint_name126self.tf_checkpoint = tf.train.Checkpoint(**tf_state)127self.tf_checkpoint_manager = tf.train.CheckpointManager(128self.tf_checkpoint,129base_directory,130max_to_keep=max_to_keep,131checkpoint_name=checkpoint_name)132
133def get_latest_checkpoint_to_restore_from(self):134"""Returns the latest checkpoint to restore from.135
136In the current implementation, this method simply returns the attribute
137`latest_checkpoint`.
138
139Subclasses can override this method to provide an alternative checkpoint to
140restore from, for example for synchronization across multiple checkpoint
141directories.
142"""
143return self.latest_checkpoint144
145@property146def latest_checkpoint(self):147"""Latest checkpoint, see `tf.train.CheckpointManager.latest_checkpoint`.148
149Returns:
150A string to the latest checkpoint. Note that this string is path-like but
151it does not really describe a file, but rather a set of files that are
152constructed from this string, by appending different file extensions. The
153returned value is `None` if there is no previously stored checkpoint in
154`base_directory` specified to `__init__()`.
155"""
156return self.tf_checkpoint_manager.latest_checkpoint157
158@property159def latest_checkpoint_flax(self):160"""Path of the latest serialized `state`.161
162Returns:
163Path of the file containing the serialized Flax state. The returned value
164is `None` if there is no previously stored checkpoint in `base_directory`
165specified to `__init__()`.
166"""
167if self.latest_checkpoint is None:168return None169return self._flax_path(self.latest_checkpoint)170
171def _flax_path(self, checkpoint):172return "{}.flax".format(checkpoint)173
174def _next_checkpoint(self, checkpoint):175if checkpoint is None:176return str(177CheckpointInfo.initialize(self.base_directory, self.checkpoint_name))178return str(CheckpointInfo.from_path(checkpoint).increment())179
180def save(self, state):181"""Saves a new checkpoints in the directory.182
183Args:
184state: Flax checkpoint to be stored.
185
186Returns:
187The checkpoint identifier ({base_directory}/ckpt-{number}).
188"""
189next_checkpoint = self._next_checkpoint(self.latest_checkpoint)190flax_path = self._flax_path(next_checkpoint)191if not tf.io.gfile.exists(self.base_directory):192tf.io.gfile.makedirs(self.base_directory)193with tf.io.gfile.GFile(flax_path, "wb") as f:194f.write(flax.serialization.to_bytes(state))195checkpoints = set(self.tf_checkpoint_manager.checkpoints)196# Write Tensorflow data last. This way Tensorflow checkpoint generation197# logic will make sure to only commit checkpoints if they complete198# successfully. A previously written `flax_path` would then simply be199# overwritten next time.200self.tf_checkpoint_manager.save()201for checkpoint in checkpoints.difference(202self.tf_checkpoint_manager.checkpoints):203tf.io.gfile.remove(self._flax_path(checkpoint))204if next_checkpoint != self.latest_checkpoint:205raise AssertionError( # pylint: disable=g-doc-exception206"Expected next_checkpoint to match latest_checkpoint: "207f"{next_checkpoint} != {self.latest_checkpoint}")208return self.latest_checkpoint # pytype: disable=bad-return-type # always-use-return-annotations209
210def restore_or_initialize(self, state):211"""Restores from the latest checkpoint, or creates a first checkpoint.212
213Args:
214state : A flax checkpoint to be stored or to serve as a template. If the
215checkoint is restored (and not initialized), then the fields of `state`
216must match the data previously stored.
217
218Returns:
219The restored `state` object. Note that all TensorFlow `Trackable`s in
220`tf_state` (see `__init__()`) are also updated.
221"""
222latest_checkpoint = self.get_latest_checkpoint_to_restore_from()223if not latest_checkpoint:224logging.info("No previous checkpoint found.")225# Only save one copy for host 0.226if jax.host_id() == 0:227self.save(state)228return state229self.tf_checkpoint.restore(latest_checkpoint)230flax_path = self._flax_path(latest_checkpoint)231with tf.io.gfile.GFile(flax_path, "rb") as f:232state = flax.serialization.from_bytes(state, f.read())233return state234
235def restore(self, state):236"""Restores from the latest checkpoint.237
238Similar to `restore_or_initialize()`, but raises a `FileNotFoundError` if
239there is no checkpoint.
240
241Args:
242state : A flax checkpoint to be stored or to serve as a template. If the
243checkoint is restored (and not initialized), then the fields of `state`
244must match the data previously stored.
245
246Returns:
247The restored `state` object. Note that all TensorFlow `Trackable`s in
248`tf_state` (see `__init__()`) are also updated.
249
250Raises:
251FileNotFoundError: If there is no checkpoint to restore.
252"""
253latest_checkpoint = self.get_latest_checkpoint_to_restore_from()254if not latest_checkpoint:255raise FileNotFoundError(f"No checkpoint found at {self.base_directory}")256return self.restore_or_initialize(state)257
258
259def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format=None):260"""Make a grid of images and Save it into an image file.261
262Args:
263ndarray (array_like): 4D mini-batch images of shape (B x H x W x C).
264fp: A filename(string) or file object.
265nrow (int, optional): Number of images displayed in each row of the grid.
266The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
267padding (int, optional): amount of padding. Default: ``2``.
268pad_value (float, optional): Value for the padded pixels. Default: ``0``.
269format(Optional): If omitted, the format to use is determined from the
270filename extension. If a file object was used instead of a filename, this
271parameter should always be used.
272"""
273if not (isinstance(ndarray, jnp.ndarray) or274(isinstance(ndarray, list) and275all(isinstance(t, jnp.ndarray) for t in ndarray))):276raise TypeError("array_like of tensors expected, got {}".format(277type(ndarray)))278
279ndarray = jnp.asarray(ndarray)280
281if ndarray.ndim == 4 and ndarray.shape[-1] == 1: # single-channel images282ndarray = jnp.concatenate((ndarray, ndarray, ndarray), -1)283
284# make the mini-batch of images into a grid285nmaps = ndarray.shape[0]286xmaps = min(nrow, nmaps)287ymaps = int(math.ceil(float(nmaps) / xmaps))288height, width = int(ndarray.shape[1] + padding), int(ndarray.shape[2] +289padding)290num_channels = ndarray.shape[3]291grid = jnp.full(292(height * ymaps + padding, width * xmaps + padding, num_channels),293pad_value).astype(jnp.float32)294k = 0295for y in range(ymaps):296for x in range(xmaps):297if k >= nmaps:298break299grid = grid.at[y * height + padding:(y + 1) * height,300x * width + padding:(x + 1) * width].set(ndarray[k])301k = k + 1302
303# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer304ndarr = jnp.clip(grid * 255.0 + 0.5, 0, 255).astype(jnp.uint8)305im = Image.fromarray(ndarr.copy())306im.save(fp, format=format)307
308
309def flatten_dict(config):310"""Flatten a hierarchical dict to a simple dict."""311new_dict = {}312for key, value in config.items():313if isinstance(value, dict):314sub_dict = flatten_dict(value)315for subkey, subvalue in sub_dict.items():316new_dict[key + "/" + subkey] = subvalue317elif isinstance(value, tuple):318new_dict[key] = str(value)319else:320new_dict[key] = value321return new_dict322