# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Training and eval worker utilities."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import os
24import time
25
26from . import logging_utils
27
28from absl import logging
29import numpy as np
30import tensorflow.compat.v1 as tf
31
32
class BaseModel(object):
  """Interface for models used by the training and eval workers."""

  def train_fn(self, x_bhwc):
    raise NotImplementedError

  def eval_fn(self, x_bhwc):
    raise NotImplementedError

  def samples_fn(self, x_bhwc):
    raise NotImplementedError

  @property
  def trainable_variables(self):
    raise NotImplementedError

  @property
  def ema(self):
    raise NotImplementedError


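# Illustrative only (not part of this module): a minimal concrete model
# sketch. The class and its return values are hypothetical, but match what
# the worker expects: eval_fn returns a dict of metrics that get averaged to
# scalars, and samples_fn returns a dict of [batch, height, width, channels]
# image tensors.
#
#   class _ToyModel(BaseModel):
#
#     def eval_fn(self, x_bhwc):
#       return {'loss': tf.reduce_mean(tf.square(x_bhwc))}
#
#     def samples_fn(self, x_bhwc):
#       return {'samples': tf.zeros_like(x_bhwc)}
#
#     @property
#     def ema(self):
#       return None  # With no EMA, the worker skips the EMA pass.

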
def _make_ema_model(orig_model, model_constructor):
  """Builds a copy of the model that reads EMA-averaged parameters."""

  if orig_model.ema is None:
    return None

  def _to_original_variable_name(name):
    # Map back to the original variable name by stripping the EMA name scope.
    parts = name.split('/')
    assert parts[0] == 'ema_scope'
    return '/'.join(parts[1:])

  def _ema_getter(getter, name, *args, **kwargs):
    v = getter(_to_original_variable_name(name), *args, **kwargs)
    v = orig_model.ema.average(v)
    if v is None:
      raise RuntimeError('invalid EMA variable name {} -> {}'.format(
          name, _to_original_variable_name(name)))
    return v

  with tf.variable_scope(
      tf.get_variable_scope(), custom_getter=_ema_getter, reuse=True):
    with tf.name_scope('ema_scope'):
      return model_constructor()


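# How the EMA getter resolves names (sketch; 'unet/w0' is a hypothetical
# variable name): when the EMA copy of the model asks for 'ema_scope/unet/w0',
# _to_original_variable_name strips the 'ema_scope/' prefix, reuse=True
# fetches the existing 'unet/w0' variable, and orig_model.ema.average(...)
# substitutes its shadow (averaged) copy:
#
#   ema = tf.train.ExponentialMovingAverage(decay=0.9999)
#   ema.apply(model.trainable_variables)  # maintained by the training loop
#   shadow_w0 = ema.average(w0)           # the value the EMA model reads

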
def run_eval(
    model_constructor,
    logdir,
    total_bs,
    master,
    input_fn,
    dataset_size):

  worker = EvalWorker(
      master=master,
      model_constructor=model_constructor,
      total_bs=total_bs,
      input_fn=input_fn)
  worker.run(logdir=logdir, once=True)


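# Usage sketch (illustrative; MyModel and the paths are hypothetical). The
# input_fn must yield dicts with an 'image' key, and batches should have a
# static batch dimension (e.g. via drop_remainder=True) to satisfy the shape
# asserts in EvalWorker._distributed:
#
#   def input_fn(params):
#     images = np.zeros([256, 32, 32, 3], dtype=np.float32)
#     return tf.data.Dataset.from_tensor_slices({'image': images}).batch(
#         params['batch_size'], drop_remainder=True)
#
#   run_eval(model_constructor=MyModel, logdir='/tmp/eval_logs',
#            total_bs=16, master='', input_fn=input_fn, dataset_size=256)

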
class EvalWorker(object):
  """Worker that runs eval and sampling passes over model checkpoints."""

  def __init__(self, master, model_constructor, total_bs, input_fn):
    self.strategy = tf.distribute.MirroredStrategy()

    self.num_cores = self.strategy.num_replicas_in_sync
    assert total_bs % self.num_cores == 0
    self.total_bs = total_bs
    self.local_bs = total_bs // self.num_cores
    logging.info('num cores: {}'.format(self.num_cores))
    logging.info('total batch size: {}'.format(self.total_bs))
    logging.info('local batch size: {}'.format(self.local_bs))

    with self.strategy.scope():
      # Dataset iterator
      dataset = input_fn(params={'batch_size': self.total_bs})
      self.eval_iterator = self.strategy.experimental_distribute_dataset(
          dataset).make_initializable_iterator()
      eval_iterator_next = next(self.eval_iterator)

      # Model
      self.model = model_constructor()
      # Model with EMA parameters
      self.ema_model = _make_ema_model(self.model, model_constructor)

      # Global step
      self.global_step = tf.train.get_global_step()
      assert self.global_step is not None, 'global step not created'

      # Eval/samples graphs
      self.eval_outputs = self._distributed(
          self.model.eval_fn, args=(eval_iterator_next,), reduction='mean')
      self.samples_outputs = self._distributed(
          self.model.samples_fn, args=(eval_iterator_next,), reduction='concat')
      # EMA versions of the above
      if self.ema_model is not None:
        self.ema_eval_outputs = self._distributed(
            self.ema_model.eval_fn,
            args=(eval_iterator_next,),
            reduction='mean')
        self.ema_samples_outputs = self._distributed(
            self.ema_model.samples_fn,
            args=(eval_iterator_next,),
            reduction='concat')

  def _distributed(self, model_fn, args, reduction):
    """Sharded computation: runs model_fn on each replica and reduces."""

    def model_wrapper(inputs_):
      return model_fn(inputs_['image'])

    out = self.strategy.run(model_wrapper, args=args)
    assert isinstance(out, dict)

    if reduction == 'mean':
      out = {
          k: tf.reduce_mean(self.strategy.reduce('mean', v))
          for k, v in out.items()
      }
      assert all(v.shape == [] for v in out.values())  # pylint: disable=g-explicit-bool-comparison
    elif reduction == 'concat':
      out = {
          k: tf.concat(self.strategy.experimental_local_results(v), axis=0)
          for k, v in out.items()
      }
      assert all(v.shape[0] == self.total_bs for v in out.values())
    else:
      raise NotImplementedError(reduction)

    return out

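  # Reduction semantics (illustrative shapes, assuming 2 replicas with a
  # local batch of 4): reduction='mean' turns per-replica values, e.g.
  # {'loss': shape []}, into a single cross-replica scalar, while
  # reduction='concat' stitches per-replica outputs, e.g.
  # {'x': shape [4, H, W, C]}, into one global batch of shape [8, H, W, C].
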
  def _make_session(self):
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    logging.info('making session...')
    return tf.Session(config=config)

  def _run_eval(self, sess, ema):
    logging.info('eval pass...')
    sess.run(self.eval_iterator.initializer)
    all_loss_lists = collections.defaultdict(list)
    run_times = []
    try:
      while True:
        # Log progress every 100 batches
        if run_times and len(run_times) % 100 == 0:
          num_batches_seen = len(list(all_loss_lists.values())[0])
          logging.info(
              'eval examples_so_far={} time_per_batch={:.5f} {}'.format(
                  num_batches_seen * self.total_bs,
                  np.mean(run_times[1:]),
                  {k: np.mean(l) for k, l in all_loss_lists.items()}))
        tstart = time.time()
        results = sess.run(self.ema_eval_outputs if ema else self.eval_outputs)
        run_times.append(time.time() - tstart)
        for k, v in results.items():
          all_loss_lists[k].append(v)
    except tf.errors.OutOfRangeError:
      # Dataset exhausted: the eval pass is complete.
      pass
    num_batches_seen = len(list(all_loss_lists.values())[0])
    logging.info('eval pass done ({} batches, {} examples)'.format(
        num_batches_seen, num_batches_seen * self.total_bs))
    results = {k: np.mean(l) for k, l in all_loss_lists.items()}
    logging.info('final eval results: {}'.format(results))
    return results

  def _run_sampling(self, sess, ema):
    sess.run(self.eval_iterator.initializer)
    logging.info('sampling...')
    samples = sess.run(
        self.ema_samples_outputs if ema else self.samples_outputs)
    logging.info('sampling done')
    return samples

  def _write_eval_and_samples(self, sess, log, curr_step, prefix, ema):
    # Samples
    samples_dict = self._run_sampling(sess, ema=ema)
    for k, v in samples_dict.items():
      assert len(v.shape) == 4 and v.shape[0] == self.total_bs
      log.summary_writer.images(
          '{}/{}'.format(prefix, k),
          np.clip(v, 0, 255).astype('uint8'),
          step=curr_step)
    log.summary_writer.flush()

    # Eval
    eval_losses = self._run_eval(sess, ema=ema)
    for k, v in eval_losses.items():
      log.write(prefix, [{k: v}], step=curr_step)

  def run(self, logdir, once, skip_non_ema_pass=True):
    """Runs the eval/sampling worker loop.

    Args:
      logdir: directory to read checkpoints from
      once: if True, writes results to a timestamped subdirectory of logdir
        and exits after evaluating one checkpoint.
      skip_non_ema_pass: if True, only runs eval/sampling with the EMA
        parameters, skipping the pass with the raw model parameters.
    """
    if once:
      eval_logdir = os.path.join(logdir, 'eval_once_{}'.format(time.time()))
    else:
      eval_logdir = logdir
    logging.info('Writing eval data to: {}'.format(eval_logdir))
    eval_log = logging_utils.Log(eval_logdir, write_graph=False)

    with self._make_session() as sess:
      # Checkpoint loading
      logging.info('making saver')
      saver = tf.train.Saver()

      for ckpt in tf.train.checkpoints_iterator(logdir):
        logging.info('restoring params...')
        saver.restore(sess, ckpt)
        global_step_val = sess.run(self.global_step)
        logging.info('restored global step: {}'.format(global_step_val))

        if not skip_non_ema_pass:
          logging.info('non-ema pass')
          self._write_eval_and_samples(
              sess,
              log=eval_log,
              curr_step=global_step_val,
              prefix='eval',
              ema=False)

        if self.ema_model is not None:
          logging.info('ema pass')
          self._write_eval_and_samples(
              sess,
              log=eval_log,
              curr_step=global_step_val,
              prefix='eval_ema',
              ema=True)

        if once:
          break
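
# Continuous-eval sketch (illustrative; MyModel, input_fn, and the path are
# hypothetical): tf.train.checkpoints_iterator blocks until new checkpoints
# appear, so with once=False the worker evaluates a training job's
# checkpoints as they are written:
#
#   worker = EvalWorker(master='', model_constructor=MyModel,
#                       total_bs=16, input_fn=input_fn)
#   worker.run(logdir='/tmp/train_logs', once=False)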