google-research

Форк
0
199 строк · 6.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Keypose estimator training / eval functions.
17

18
Adapted from 'Discovery of Latent 3D Keypoints via End-to-end
19
Geometric Reasoning' keypoint network.
20

21
Given a 2D image and viewpoint, predict a set of 3D keypoints that
22
match the target examples.
23

24
Can be instance or class specific, depending on training set.
25

26
Typical invocation:
27
$ python3 -m keypose.trainer configs/bottle_0_t5 /tmp/model
28
"""
29

30
import os
31
import sys
32

33
from tensorflow import estimator as tf_estimator
34
from tensorflow import keras
35

36
from keypose import estimator as est
37
from keypose import inputs as inp
38
from keypose import utils
39

40

41
# Minimum number of global steps that must separate two SavedModel exports
# made by the evaluation hook.
_MIN_STEPS_BETWEEN_EXPORTS = 4000


def _checkpoint_step(ckpt):
  """Returns the step suffix of a checkpoint path, as a string.

  Args:
    ckpt: Checkpoint path of the form '<prefix>.ckpt-<step>'.

  Returns:
    The text that follows the last '.ckpt-' in `ckpt`.

  Raises:
    ValueError: If `ckpt` does not contain '.ckpt-'.
  """
  # Split on the LAST marker so that a '.ckpt-' occurring earlier in the
  # path (e.g. in a directory name) cannot be mistaken for the step suffix.
  _, marker, step = ckpt.rpartition('.ckpt-')
  if not marker:
    raise ValueError('Unexpected checkpoint name: %s' % ckpt)
  return step


def train_and_eval(params,
                   model_fn,
                   input_fn,
                   keep_checkpoint_every_n_hours=0.5,
                   save_checkpoints_secs=100,
                   eval_steps=0,
                   eval_start_delay_secs=10,
                   eval_throttle_secs=100,
                   save_summary_steps=50):
  """Trains and evaluates our model.

  Supports local and distributed training.

  Builds an Estimator, runs tf.estimator.train_and_evaluate, and exports
  SavedModels into params.model_dir: periodically from an evaluation hook and
  once more after training finishes.

  Args:
    params: ConfigParams class with model training and network parameters.
      Its model_dir attribute is overwritten when the RunConfig supplies one.
    model_fn: A func with prototype model_fn(features, labels, mode, hparams).
    input_fn: A factory called as input_fn(params, split=...) that returns the
      input function for the tf.estimator.Estimator.
    keep_checkpoint_every_n_hours: Number of hours between each checkpoint to be
      saved.
    save_checkpoints_secs: Save checkpoints every this many seconds.
    eval_steps: Number of steps to evaluate model; 0 for one epoch.
    eval_start_delay_secs: Start evaluating after waiting for this many seconds.
    eval_throttle_secs: Do not re-evaluate unless the last evaluation was
      started at least this many seconds ago.
    save_summary_steps: Save summaries every this many steps.

  Raises:
    RuntimeError: If no checkpoint exists after training, so that the final
      SavedModel cannot be exported.
  """

  mparams = params.model_params

  run_config = tf_estimator.RunConfig(
      keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
      save_checkpoints_secs=save_checkpoints_secs,
      save_summary_steps=save_summary_steps)

  # A model dir carried by the RunConfig takes precedence over the one in
  # params (presumably it comes from the cluster environment - TODO confirm).
  if run_config.model_dir:
    params.model_dir = run_config.model_dir
  print('\nCreating estimator with model dir %s' % params.model_dir)
  estimator = tf_estimator.Estimator(
      model_fn=model_fn,
      model_dir=params.model_dir,
      config=run_config,
      params=params)

  print('\nCreating train_spec')
  train_spec = tf_estimator.TrainSpec(
      input_fn=input_fn(params, split='train'), max_steps=params.steps)

  print('\nCreating eval_spec')

  def serving_input_receiver_fn():
    """Builds the serving input receiver function from Keras inputs.

    Returns:
      The result of build_raw_serving_input_receiver_fn(features), which is
      what gets passed to Estimator.export_saved_model as its serving input
      receiver function.
    """
    modelx = mparams.modelx
    modely = mparams.modely
    offsets = keras.Input(shape=(3,), name='offsets', dtype='float32')
    hom = keras.Input(shape=(3, 3), name='hom', dtype='float32')
    to_world = keras.Input(shape=(4, 4), name='to_world_L', dtype='float32')
    img_l = keras.Input(
        shape=(modely, modelx, 3), name='img_L', dtype='float32')
    img_r = keras.Input(
        shape=(modely, modelx, 3), name='img_R', dtype='float32')
    features = {
        'img_L': img_l,
        'img_R': img_r,
        'to_world_L': to_world,
        'offsets': offsets,
        'hom': hom
    }
    return tf_estimator.export.build_raw_serving_input_receiver_fn(features)

  class SaveModel(tf_estimator.SessionRunHook):
    """Saves a model in SavedModel format."""

    def __init__(self, estimator, output_dir):
      self.output_dir = output_dir
      self.estimator = estimator
      # Step number of the checkpoint used for the most recent export.
      self.save_num = 0

    def begin(self):
      """Exports a SavedModel if enough steps passed since the last export."""
      ckpt = self.estimator.latest_checkpoint()
      print('Latest checkpoint in hook:', ckpt)
      if ckpt is None:
        # Nothing has been checkpointed yet, so there is nothing to export.
        return
      ckpt_num_str = _checkpoint_step(ckpt)
      ckpt_num = int(ckpt_num_str)
      if ckpt_num - self.save_num > _MIN_STEPS_BETWEEN_EXPORTS:
        fname = os.path.join(self.output_dir, 'saved_model-' + ckpt_num_str)
        print('**** Saving model in train hook: %s' % fname)
        self.estimator.export_saved_model(fname, serving_input_receiver_fn())
        self.save_num = ckpt_num

  saver_hook = SaveModel(estimator, params.model_dir)

  # EvalSpec takes steps=None rather than 0 for the "one epoch" case.
  if eval_steps == 0:
    eval_steps = None
  eval_spec = tf_estimator.EvalSpec(
      input_fn=input_fn(params, split='val'),
      steps=eval_steps,
      hooks=[saver_hook],
      start_delay_secs=eval_start_delay_secs,
      throttle_secs=eval_throttle_secs)

  # Only the chief records the parameters used for this run.
  if run_config.is_chief:
    outdir = params.model_dir
    if outdir is not None:
      print('Writing params to %s' % outdir)
      os.makedirs(outdir, exist_ok=True)
      params.write_yaml(os.path.join(outdir, 'params.yaml'))

  print('\nRunning estimator')
  tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  print('\nSaving last model')
  ckpt = estimator.latest_checkpoint()
  print('Last checkpoint:', ckpt)
  if ckpt is None:
    raise RuntimeError(
        'No checkpoint found in %s; cannot export a SavedModel.' %
        params.model_dir)
  ckpt_num_str = _checkpoint_step(ckpt)
  fname = os.path.join(params.model_dir, 'saved_model-' + ckpt_num_str)
  print('**** Saving last model: %s' % fname)
  estimator.export_saved_model(fname, serving_input_receiver_fn())
160

161

162
def main(argv):
  """Loads the training configuration and runs training and evaluation.

  Args:
    argv: Command-line arguments. argv[1] is the config file path relative to
      utils.KEYPOSE_PATH, without the '.yaml' extension. The optional argv[2]
      is the model directory, defaulting to '/tmp/model'.
  """
  if len(argv) < 2:
    print('Usage: ./trainer.py <config_file, e.g., configs/bottle_0_t5> '
          '[model_dir (/tmp/model)]')
    # Missing arguments are an error: use sys.exit (always available, unlike
    # the interactive `exit` helper) with a non-zero status.
    sys.exit(1)

  config_file = argv[1]
  model_dir = argv[2] if len(argv) > 2 else '/tmp/model'

  fname = os.path.join(utils.KEYPOSE_PATH, config_file + '.yaml')
  # First pass: read the config only to find the dataset directory.
  with open(fname, 'r') as f:
    params, _, _ = utils.get_params(param_file=f)
  dset_dir = os.path.join(utils.KEYPOSE_PATH, params.dset_dir)
  # Configuration has the dset directory, now get more info from there.
  with open(fname, 'r') as f:
    params, _, _ = utils.get_params(
        param_file=f,
        cam_file=os.path.join(dset_dir, 'data_params.pbtxt'))
  params.model_dir = model_dir
  params.dset_dir = dset_dir

  print('Parameters to train and eval:\n', params.make_dict())

  train_and_eval(
      params,
      model_fn=est.est_model_fn,
      input_fn=inp.create_input_fn,
      save_checkpoints_secs=600,
      eval_throttle_secs=600,
      eval_steps=1000,
  )
196

197

198
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
  main(sys.argv)
200

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.