google-research
140 строк · 4.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Train the Contrack model."""
17
18import logging19import os20
21from absl import app22from absl import flags23
24import tensorflow as tf25
26from contrack import data27from contrack import encoding28from contrack import env29from contrack import model30
# Command-line flags consumed by train().
flags.DEFINE_string('model_path', '',
                    'Base output directory where the model is stored.')
# Exactly one config source is used: train() prefers --config_path and falls
# back to --config_json when the path is empty.
flags.DEFINE_string('config_path', '', 'File path of config json file.')
flags.DEFINE_string(
    'config_json', '',
    'The contents of a json config file if --config_path was not provided.')
# The accepted mode values are checked in train(), not by the flag parser.
flags.DEFINE_string(
    'mode', '',
    'How to train the model, either "only_new_entities", "only_tracking", "full" or "two_steps".'
)
flags.DEFINE_string(
    'train_data_glob', '',
    'A TF glob pattern specifying the location of the training data files.')
# Optional: when left empty, train() runs without validation data.
flags.DEFINE_string(
    'eval_data_glob', '',
    'A TF glob pattern specifying the location of the validation data files.')
FLAGS = flags.FLAGS
49
# Modes that train a single model in one pass; 'two_steps' is handled
# separately because it trains two models back to back.
_SINGLE_STAGE_MODES = ('only_new_entities', 'full', 'only_tracking')
_TWO_STEPS_MODE = 'two_steps'


def _compile_model(mode, config, clipnorm=None):
  """Builds a ContrackModel for `mode` and compiles it with an Adam optimizer.

  Args:
    mode: Mode string handed to the model, loss and metrics constructors.
    config: Config object; only `learning_rate` is read here.
    clipnorm: Optional `clipnorm` value for the Adam optimizer. When None the
      argument is not passed to the optimizer at all.

  Returns:
    The compiled model.
  """
  contrack_model = model.ContrackModel(mode)
  loss = model.ContrackLoss(mode)
  metrics = model.build_metrics(mode)
  optimizer_kwargs = {'learning_rate': config.learning_rate}
  if clipnorm is not None:
    optimizer_kwargs['clipnorm'] = clipnorm
  optimizer = tf.keras.optimizers.Adam(**optimizer_kwargs)
  contrack_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
  return contrack_model


def _fit_model(contrack_model, config, train_data, eval_data, callbacks):
  """Runs `fit` on a compiled model with the schedule taken from `config`.

  Args:
    contrack_model: A compiled model.
    config: Config object; `max_steps` and `steps_per_epoch` are read here.
    train_data: Training data passed as `x`.
    eval_data: Validation data, or None to train without validation.
    callbacks: List of Keras callbacks passed through to `fit`.
  """
  contrack_model.fit(
      x=train_data,
      # The config expresses the budget in steps; Keras wants whole epochs.
      epochs=int(config.max_steps / config.steps_per_epoch),
      callbacks=callbacks,
      steps_per_epoch=config.steps_per_epoch,
      validation_data=eval_data)


def train(argv):
  """Train a contrack model.

  Reads the config and data locations from the command-line flags, trains the
  model according to --mode and saves the model, the config and the encodings
  under --model_path.

  Args:
    argv: Unused positional command-line arguments.

  Raises:
    ValueError: If --mode is not one of the supported modes, or if neither
      --config_path nor --config_json is provided.
  """
  del argv  # Unused.

  # Reject a bad --mode up front, before any config or data is loaded, so a
  # typo does not cost a full pass over the input files.
  mode = FLAGS.mode
  if mode not in _SINGLE_STAGE_MODES and mode != _TWO_STEPS_MODE:
    raise ValueError('Unknown mode "%s"' % mode)

  if FLAGS.config_path:
    config = env.ContrackConfig.load_from_path(FLAGS.config_path)
  elif FLAGS.config_json:
    config = env.ContrackConfig.load_from_json(FLAGS.config_json)
  else:
    raise ValueError('Must provide --config_path or --config_json')

  logging.info('Training with config:\n%s', config)
  encodings = encoding.Encodings()
  env.Env.init(config, encodings)
  environment = env.Env.get()

  logging.info('Reading training data from %s', FLAGS.train_data_glob)
  train_data = data.read_training_data(FLAGS.train_data_glob, config, encodings)

  if FLAGS.eval_data_glob:
    logging.info('Reading validation data from %s', FLAGS.eval_data_glob)
    eval_data = data.read_eval_data(FLAGS.eval_data_glob, config, encodings)
  else:
    eval_data = None

  tensorboard_dir = os.path.join(FLAGS.model_path, 'tensorboard')
  checkpoint_dir = os.path.join(FLAGS.model_path, 'checkpoints')
  callbacks = [
      tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir),
      tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir),
      tf.keras.callbacks.TerminateOnNaN()
  ]

  if mode in _SINGLE_STAGE_MODES:
    # Compile model
    contrack_model = _compile_model(mode, config)

    # Do the actual training
    _fit_model(contrack_model, config, train_data, eval_data, callbacks)
  else:
    # mode == 'two_steps': first train the new-entity model (with gradient
    # norm clipping), then train the tracking model starting from its weights.
    logging.info('Training new entity model...')
    new_id_model = _compile_model('only_new_entities', config, clipnorm=1.0)
    _fit_model(new_id_model, config, train_data, eval_data, callbacks)

    logging.info('Training tracking model...')
    contrack_model = _compile_model('only_tracking', config)
    # Weights are copied after compile and before fit, as in the original flow.
    contrack_model.init_weights_from_new_entity_model(new_id_model)
    _fit_model(contrack_model, config, train_data, eval_data, callbacks)

  # Save it
  filepath = FLAGS.model_path
  with tf.keras.utils.custom_object_scope(model.get_custom_objects()):
    tf.keras.models.save_model(contrack_model, filepath)
  environment.config.save(filepath)
  environment.encodings.save(filepath)
138
# Script entry point: hand control to absl, which invokes train() as the
# program's main function.
if __name__ == '__main__':
  app.run(main=train)