worker_util.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training and eval worker utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import time

from . import logging_utils

from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

class BaseModel(object):
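  """Interface for models driven by the training/eval workers.

  `x_bhwc` denotes a batch of images laid out as
  (batch, height, width, channels).
  """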

  def train_fn(self, x_bhwc):
    raise NotImplementedError

  def eval_fn(self, x_bhwc):
    raise NotImplementedError

  def samples_fn(self, x_bhwc):
    raise NotImplementedError

  @property
  def trainable_variables(self):
    raise NotImplementedError

  @property
  def ema(self):
    raise NotImplementedError


def _make_ema_model(orig_model, model_constructor):
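  """Builds a copy of `orig_model` that reads variables through their EMA.

  Re-enters the current variable scope with a custom getter that substitutes
  each variable with its exponential-moving-average shadow, so the returned
  model evaluates with averaged parameters. Returns None if `orig_model`
  keeps no EMA.
  """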

  # Model with EMA parameters
  if orig_model.ema is None:
    return None

  def _to_original_variable_name(name):
    # Map 'ema_scope/<name>' back to the original variable name '<name>'.
    parts = name.split('/')
    assert parts[0] == 'ema_scope'
    return '/'.join(parts[1:])

  def _ema_getter(getter, name, *args, **kwargs):
    v = getter(_to_original_variable_name(name), *args, **kwargs)
    v = orig_model.ema.average(v)
    if v is None:
      raise RuntimeError('invalid EMA variable name {} -> {}'.format(
          name, _to_original_variable_name(name)))
    return v

  with tf.variable_scope(
      tf.get_variable_scope(), custom_getter=_ema_getter, reuse=True):
    with tf.name_scope('ema_scope'):
      return model_constructor()


def run_eval(
    model_constructor,
    logdir,
    total_bs,
    master,
    input_fn,
    dataset_size):
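  """Evaluates a single checkpoint from `logdir` with a fresh EvalWorker.

  Note: `dataset_size` is accepted but unused, and `master` is ignored by
  EvalWorker, which always uses MirroredStrategy.
  """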

  worker = EvalWorker(
      master=master,
      model_constructor=model_constructor,
      total_bs=total_bs,
      input_fn=input_fn)
  worker.run(logdir=logdir, once=True)


class EvalWorker(object):
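  """Worker that repeatedly evaluates and samples from saved checkpoints."""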
  def __init__(self, master, model_constructor, total_bs, input_fn):
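    """Builds the distributed eval and sampling graphs.

    Args:
      master: unused here; MirroredStrategy drives all local replicas.
      model_constructor: callable returning a BaseModel instance.
      total_bs: global batch size; must divide evenly across replicas.
      input_fn: callable that builds the eval dataset, called with
        params={'batch_size': total_bs}.
    """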
    self.strategy = tf.distribute.MirroredStrategy()

    self.num_cores = self.strategy.num_replicas_in_sync
    assert total_bs % self.num_cores == 0
    self.total_bs = total_bs
    self.local_bs = total_bs // self.num_cores
    logging.info('num cores: {}'.format(self.num_cores))
    logging.info('total batch size: {}'.format(self.total_bs))
    logging.info('local batch size: {}'.format(self.local_bs))

    with self.strategy.scope():
      # Dataset iterator
      dataset = input_fn(params={'batch_size': self.total_bs})
      self.eval_iterator = self.strategy.experimental_distribute_dataset(
          dataset).make_initializable_iterator()
      eval_iterator_next = next(self.eval_iterator)

      # Model
      self.model = model_constructor()
      # Model with EMA parameters
      self.ema_model = _make_ema_model(self.model, model_constructor)

      # Global step
      self.global_step = tf.train.get_global_step()
      assert self.global_step is not None, 'global step not created'

      # Eval/samples graphs
      self.eval_outputs = self._distributed(
          self.model.eval_fn, args=(eval_iterator_next,), reduction='mean')
      self.samples_outputs = self._distributed(
          self.model.samples_fn, args=(eval_iterator_next,), reduction='concat')
      # EMA versions of the above
      if self.ema_model is not None:
        self.ema_eval_outputs = self._distributed(
            self.ema_model.eval_fn,
            args=(eval_iterator_next,),
            reduction='mean')
        self.ema_samples_outputs = self._distributed(
            self.ema_model.samples_fn,
            args=(eval_iterator_next,),
            reduction='concat')

  def _distributed(self, model_fn, args, reduction):
    """Runs `model_fn` on each replica and reduces the per-replica outputs."""

    def model_wrapper(inputs_):
      return model_fn(inputs_['image'])

    out = self.strategy.run(model_wrapper, args=args)
    assert isinstance(out, dict)

    if reduction == 'mean':
      # Average each metric across replicas, then down to a scalar.
      out = {
          k: tf.reduce_mean(self.strategy.reduce('mean', v))
          for k, v in out.items()
      }
      assert all(v.shape == [] for v in out.values())  # pylint: disable=g-explicit-bool-comparison
    elif reduction == 'concat':
      # Stitch per-replica results back into one global batch.
      out = {
          k: tf.concat(self.strategy.experimental_local_results(v), axis=0)
          for k, v in out.items()
      }
      assert all(v.shape[0] == self.total_bs for v in out.values())
    else:
      raise NotImplementedError(reduction)

    return out

  def _make_session(self):
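    """Creates a tf.Session with soft device placement enabled."""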
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    logging.info('making session...')
    return tf.Session(config=config)

  def _run_eval(self, sess, ema):
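    """Runs one full pass over the eval dataset and returns mean metrics.

    Uses the EMA eval graph when `ema` is True. Iterates until the dataset
    iterator raises OutOfRangeError.
    """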
    logging.info('eval pass...')
    sess.run(self.eval_iterator.initializer)
    all_loss_lists = collections.defaultdict(list)
    run_times = []
    try:
      while True:
        # Log progress
        if run_times and len(run_times) % 100 == 0:
          num_batches_seen = len(list(all_loss_lists.values())[0])
          logging.info(
              'eval examples_so_far={} time_per_batch={:.5f} {}'.format(
                  num_batches_seen * self.total_bs,
                  np.mean(run_times[1:]),
                  {k: np.mean(l) for k, l in all_loss_lists.items()}))
        tstart = time.time()
        results = sess.run(self.ema_eval_outputs if ema else self.eval_outputs)
        run_times.append(time.time() - tstart)
        for k, v in results.items():
          all_loss_lists[k].append(v)
    except tf.errors.OutOfRangeError:
      pass
    num_batches_seen = len(list(all_loss_lists.values())[0])
    logging.info('eval pass done ({} batches, {} examples)'.format(
        num_batches_seen, num_batches_seen * self.total_bs))
    results = {k: np.mean(l) for k, l in all_loss_lists.items()}
    logging.info('final eval results: {}'.format(results))
    return results

  def _run_sampling(self, sess, ema):
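    """Draws one batch of samples, optionally with EMA parameters."""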
    sess.run(self.eval_iterator.initializer)
    logging.info('sampling...')
    samples = sess.run(
        self.ema_samples_outputs if ema else self.samples_outputs)
    logging.info('sampling done')
    return samples

  def _write_eval_and_samples(self, sess, log, curr_step, prefix, ema):
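    """Writes one round of samples and eval metrics under `prefix`.

    Samples are logged as uint8 image summaries at `curr_step`; eval losses
    are written to `log` as scalars.
    """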
    # Samples
    samples_dict = self._run_sampling(sess, ema=ema)
    for k, v in samples_dict.items():
      assert len(v.shape) == 4 and v.shape[0] == self.total_bs
      log.summary_writer.images(
          '{}/{}'.format(prefix, k),
          np.clip(v, 0, 255).astype('uint8'),
          step=curr_step)
    log.summary_writer.flush()

    # Eval
    eval_losses = self._run_eval(sess, ema=ema)
    for k, v in eval_losses.items():
      log.write(prefix, [{k: v}], step=curr_step)

  def run(self, logdir, once, skip_non_ema_pass=True):
    """Runs the eval/sampling worker loop.

    Args:
      logdir: directory to read checkpoints from.
      once: if True, writes results to a timestamped `eval_once_*`
        subdirectory of `logdir` and exits after evaluating one checkpoint.
      skip_non_ema_pass: if True (the default), skips eval/sampling with the
        raw (non-EMA) parameters and only runs the EMA pass.
    """
    if once:
      eval_logdir = os.path.join(logdir, 'eval_once_{}'.format(time.time()))
    else:
      eval_logdir = logdir
    logging.info('Writing eval data to: {}'.format(eval_logdir))
    eval_log = logging_utils.Log(eval_logdir, write_graph=False)

    with self._make_session() as sess:
      # Checkpoint loading
      logging.info('making saver')
      saver = tf.train.Saver()

      for ckpt in tf.train.checkpoints_iterator(logdir):
        logging.info('restoring params...')
        saver.restore(sess, ckpt)
        global_step_val = sess.run(self.global_step)
        logging.info('restored global step: {}'.format(global_step_val))

        if not skip_non_ema_pass:
          logging.info('non-ema pass')
          self._write_eval_and_samples(
              sess,
              log=eval_log,
              curr_step=global_step_val,
              prefix='eval',
              ema=False)

        if self.ema_model is not None:
          logging.info('ema pass')
          self._write_eval_and_samples(
              sess,
              log=eval_log,
              curr_step=global_step_val,
              prefix='eval_ema',
              ema=True)

        if once:
          break
