google-research
199 строк · 6.5 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Keypose estimator training / eval functions.
17
18Adapted from 'Discovery of Latent 3D Keypoints via End-to-end
19Geometric Reasoning' keypoint network.
20
21Given a 2D image and viewpoint, predict a set of 3D keypoints that
22match the target examples.
23
24Can be instance or class specific, depending on training set.
25
26Typical invocation:
27$ python3 -m keypose.trainer configs/bottle_0_t5 /tmp/model
28"""
29
30import os
31import sys
32
33from tensorflow import estimator as tf_estimator
34from tensorflow import keras
35
36from keypose import estimator as est
37from keypose import inputs as inp
38from keypose import utils
39
40
def _checkpoint_step_str(ckpt_path):
  """Returns the text following the last '.ckpt-' in a checkpoint path.

  Args:
    ckpt_path: Checkpoint path string, e.g. '/tmp/model/model.ckpt-1200'.

  Returns:
    The suffix after the final '.ckpt-' separator ('1200' in the example).

  Raises:
    IndexError: If ckpt_path does not contain '.ckpt-'.
  """
  # Split on the LAST separator so that a '.ckpt-' appearing in a directory
  # name is not mistaken for the step delimiter.
  return ckpt_path.rsplit('.ckpt-', 1)[1]


def train_and_eval(params,
                   model_fn,
                   input_fn,
                   keep_checkpoint_every_n_hours=0.5,
                   save_checkpoints_secs=100,
                   eval_steps=0,
                   eval_start_delay_secs=10,
                   eval_throttle_secs=100,
                   save_summary_steps=50,
                   export_every_n_steps=4000):
  """Trains and evaluates our model, exporting SavedModels along the way.

  Supports local and distributed training.

  Builds a tf.estimator.Estimator from `model_fn`, runs
  `train_and_evaluate` on it, and finally exports the last checkpoint as a
  SavedModel under `params.model_dir`. `params.model_dir` is overwritten by
  the RunConfig's model dir when the latter is set.

  Args:
    params: ConfigParams class with model training and network parameters.
    model_fn: A func with prototype model_fn(features, labels, mode, hparams).
    input_fn: A factory called as input_fn(params, split=...) whose result is
      used as the input function of the train / eval specs.
    keep_checkpoint_every_n_hours: Number of hours between each checkpoint to
      be saved.
    save_checkpoints_secs: Save checkpoints every this many seconds.
    eval_steps: Number of steps to evaluate model; 0 for one epoch.
    eval_start_delay_secs: Start evaluating after waiting for this many
      seconds.
    eval_throttle_secs: Do not re-evaluate unless the last evaluation was
      started at least this many seconds ago.
    save_summary_steps: Save summaries every this many steps.
    export_every_n_steps: The evaluation hook exports a SavedModel only when
      the latest checkpoint is more than this many steps past the previous
      export made by the hook.
  """
  mparams = params.model_params

  run_config = tf_estimator.RunConfig(
      keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
      save_checkpoints_secs=save_checkpoints_secs,
      save_summary_steps=save_summary_steps)

  if run_config.model_dir:
    params.model_dir = run_config.model_dir
  print('\nCreating estimator with model dir %s' % params.model_dir)
  estimator = tf_estimator.Estimator(
      model_fn=model_fn,
      model_dir=params.model_dir,
      config=run_config,
      params=params)

  print('\nCreating train_spec')
  train_spec = tf_estimator.TrainSpec(
      input_fn=input_fn(params, split='train'), max_steps=params.steps)

  print('\nCreating eval_spec')

  def serving_input_receiver_fn():
    """Builds the serving input receiver function from Keras inputs.

    Returns:
      The function produced by
      tf.estimator.export.build_raw_serving_input_receiver_fn for the
      'img_L', 'img_R', 'to_world_L', 'offsets' and 'hom' features. This is
      what gets passed to Estimator.export_saved_model.
    """
    modelx = mparams.modelx
    modely = mparams.modely
    offsets = keras.Input(shape=(3,), name='offsets', dtype='float32')
    hom = keras.Input(shape=(3, 3), name='hom', dtype='float32')
    to_world = keras.Input(shape=(4, 4), name='to_world_L', dtype='float32')
    img_l = keras.Input(
        shape=(modely, modelx, 3), name='img_L', dtype='float32')
    img_r = keras.Input(
        shape=(modely, modelx, 3), name='img_R', dtype='float32')
    features = {
        'img_L': img_l,
        'img_R': img_r,
        'to_world_L': to_world,
        'offsets': offsets,
        'hom': hom
    }
    return tf_estimator.export.build_raw_serving_input_receiver_fn(features)

  class SaveModel(tf_estimator.SessionRunHook):
    """Saves a model in SavedModel format."""

    def __init__(self, estimator, output_dir):
      self.output_dir = output_dir
      self.estimator = estimator
      # Step number of the most recent export made by this hook.
      self.save_num = 0

    def begin(self):
      ckpt = self.estimator.latest_checkpoint()
      print('Latest checkpoint in hook:', ckpt)
      if ckpt is None:
        # Nothing has been checkpointed yet, so there is nothing to export.
        return
      ckpt_num_str = _checkpoint_step_str(ckpt)
      ckpt_num = int(ckpt_num_str)
      if ckpt_num - self.save_num > export_every_n_steps:
        fname = os.path.join(self.output_dir, 'saved_model-' + ckpt_num_str)
        print('**** Saving model in train hook: %s' % fname)
        self.estimator.export_saved_model(fname, serving_input_receiver_fn())
        self.save_num = ckpt_num

  saver_hook = SaveModel(estimator, params.model_dir)

  # EvalSpec takes steps=None to mean "until the input is exhausted".
  if eval_steps == 0:
    eval_steps = None
  eval_spec = tf_estimator.EvalSpec(
      input_fn=input_fn(params, split='val'),
      steps=eval_steps,
      hooks=[saver_hook],
      start_delay_secs=eval_start_delay_secs,
      throttle_secs=eval_throttle_secs)

  # Only the chief writes the parameter file.
  if run_config.is_chief:
    outdir = params.model_dir
    if outdir is not None:
      print('Writing params to %s' % outdir)
      os.makedirs(outdir, exist_ok=True)
      params.write_yaml(os.path.join(outdir, 'params.yaml'))

  print('\nRunning estimator')
  tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  print('\nSaving last model')
  ckpt = estimator.latest_checkpoint()
  print('Last checkpoint:', ckpt)
  if ckpt is None:
    # Without a checkpoint there are no weights to export; report it instead
    # of failing with an AttributeError on None.
    print('No checkpoint found in %s; skipping final export' %
          params.model_dir)
    return
  ckpt_num_str = _checkpoint_step_str(ckpt)
  fname = os.path.join(params.model_dir, 'saved_model-' + ckpt_num_str)
  print('**** Saving last model: %s' % fname)
  estimator.export_saved_model(fname, serving_input_receiver_fn())
160
161
def main(argv):
  """Parses the command line, loads the configuration and starts training.

  Args:
    argv: Command-line list: argv[0] is the program name, argv[1] the config
      file path relative to utils.KEYPOSE_PATH without the '.yaml' extension,
      and the optional argv[2] the model directory (default '/tmp/model').

  Raises:
    SystemExit: With status 1, after printing usage, when no config file is
      given.
  """
  if len(argv) < 2:
    print('Usage: ./trainer.py <config_file, e.g., configs/bottle_0_t5> '
          '[model_dir (/tmp/model)]')
    # sys.exit rather than the site-provided exit(), which is intended for
    # interactive sessions; a usage error reports a non-zero status.
    sys.exit(1)

  config_file = argv[1]
  model_dir = argv[2] if len(argv) > 2 else '/tmp/model'

  fname = os.path.join(utils.KEYPOSE_PATH, config_file + '.yaml')
  # First pass: read the configuration only to find the dataset directory.
  with open(fname, 'r') as f:
    params, _, _ = utils.get_params(param_file=f)
  dset_dir = os.path.join(utils.KEYPOSE_PATH, params.dset_dir)
  # Configuration has the dset directory, now get more info from there.
  with open(fname, 'r') as f:
    params, _, _ = utils.get_params(
        param_file=f,
        cam_file=os.path.join(dset_dir, 'data_params.pbtxt'))
  params.model_dir = model_dir
  params.dset_dir = dset_dir

  print('Parameters to train and eval:\n', params.make_dict())

  train_and_eval(
      params,
      model_fn=est.est_model_fn,
      input_fn=inp.create_input_fn,
      save_checkpoints_secs=600,
      eval_throttle_secs=600,
      eval_steps=1000,
  )
196
197
if __name__ == '__main__':
  # Script entry point: hand the full command line (program name included)
  # to main().
  main(argv=sys.argv)
200