google-research

Форк
0
140 строк · 4.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Train the Contrack model."""
17

18
import logging
19
import os
20

21
from absl import app
22
from absl import flags
23

24
import tensorflow as tf
25

26
from contrack import data
27
from contrack import encoding
28
from contrack import env
29
from contrack import model
30

31
# Command-line flags configuring a training run.

# Output location; the trained model is written under this directory.
flags.DEFINE_string('model_path', '',
                    'Base output directory where the model is stored.')

# The config is supplied either as a file (--config_path) or inline
# (--config_json).
flags.DEFINE_string('config_path', '', 'File path of config json file.')
flags.DEFINE_string(
    'config_json', '',
    'The contents of a json config file if --config_path was not provided.')

# Training regime selector.
flags.DEFINE_string(
    'mode', '',
    'How to train the model, either "only_new_entities", "only_tracking", '
    '"full" or "two_steps".')

# Input data locations.
flags.DEFINE_string(
    'train_data_glob', '',
    'A TF glob pattern specifying the location of the training data files.')
flags.DEFINE_string(
    'eval_data_glob', '',
    'A TF glob pattern specifying the location of the validation data files.')

FLAGS = flags.FLAGS
48

49

50
# Modes that train a single model in one pass.
_SINGLE_STAGE_MODES = ('only_new_entities', 'only_tracking', 'full')
# Every value accepted by --mode.
_SUPPORTED_MODES = _SINGLE_STAGE_MODES + ('two_steps',)


def _build_compiled_model(mode, learning_rate, **optimizer_kwargs):
  """Creates a ContrackModel for `mode` and compiles it with Adam.

  Args:
    mode: Mode string passed to the model, loss and metrics constructors.
    learning_rate: Learning rate for the Adam optimizer.
    **optimizer_kwargs: Extra keyword arguments forwarded to the Adam
      constructor (e.g. clipnorm).

  Returns:
    The compiled model.
  """
  contrack_model = model.ContrackModel(mode)
  loss = model.ContrackLoss(mode)
  metrics = model.build_metrics(mode)
  optimizer = tf.keras.optimizers.Adam(
      learning_rate=learning_rate, **optimizer_kwargs)
  contrack_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
  return contrack_model


def _fit(contrack_model, train_data, eval_data, config, callbacks):
  """Fits `contrack_model`, splitting config.max_steps into epochs.

  Args:
    contrack_model: A compiled model.
    train_data: Training data passed to fit as `x`.
    eval_data: Validation data, or None to train without validation.
    config: Config providing `max_steps` and `steps_per_epoch`.
    callbacks: List of Keras callbacks to run during training.
  """
  contrack_model.fit(
      x=train_data,
      # A partial trailing epoch is dropped by the int() truncation.
      epochs=int(config.max_steps / config.steps_per_epoch),
      callbacks=callbacks,
      steps_per_epoch=config.steps_per_epoch,
      validation_data=eval_data)


def train(argv):
  """Train a contrack model and save it under --model_path.

  Depending on --mode, either a single model is trained, or ("two_steps") a
  new-entity model is trained first and used to initialise the weights of a
  tracking model which is then trained and saved.

  Args:
    argv: Positional command line arguments from absl; unused.

  Raises:
    ValueError: If neither --config_path nor --config_json is given, or if
      --mode is not one of the supported modes.
  """
  del argv  # Unused.

  mode = FLAGS.mode
  if FLAGS.config_path:
    config = env.ContrackConfig.load_from_path(FLAGS.config_path)
  elif FLAGS.config_json:
    config = env.ContrackConfig.load_from_json(FLAGS.config_json)
  else:
    raise ValueError('Must provide --config_path or --config_json')

  # Reject a bad mode before initialising the environment and reading any
  # data, so that a typo in --mode fails immediately.
  if mode not in _SUPPORTED_MODES:
    raise ValueError('Unknown mode "%s"' % mode)

  logging.info('Training with config:\n%s', config)
  encodings = encoding.Encodings()
  env.Env.init(config, encodings)
  environment = env.Env.get()

  logging.info('Reading training data from %s', FLAGS.train_data_glob)
  train_data = data.read_training_data(FLAGS.train_data_glob, config, encodings)

  if FLAGS.eval_data_glob:
    logging.info('Reading validation data from %s', FLAGS.eval_data_glob)
    eval_data = data.read_eval_data(FLAGS.eval_data_glob, config, encodings)
  else:
    eval_data = None

  tensorboard_dir = os.path.join(FLAGS.model_path, 'tensorboard')
  checkpoint_dir = os.path.join(FLAGS.model_path, 'checkpoints')
  # The same callback instances (and hence the same output directories) are
  # used for every fit call below, including both stages of "two_steps".
  callbacks = [
      tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir),
      tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir),
      tf.keras.callbacks.TerminateOnNaN()
  ]

  if mode == 'two_steps':
    # Stage 1: the new-entity model, trained with gradient-norm clipping.
    logging.info('Training new entity model...')
    new_id_model = _build_compiled_model(
        'only_new_entities', config.learning_rate, clipnorm=1.0)
    _fit(new_id_model, train_data, eval_data, config, callbacks)

    # Stage 2: the tracking model, seeded from the stage 1 weights after it
    # has been compiled and before it is fitted.
    logging.info('Training tracking model...')
    contrack_model = _build_compiled_model('only_tracking',
                                           config.learning_rate)
    contrack_model.init_weights_from_new_entity_model(new_id_model)
    _fit(contrack_model, train_data, eval_data, config, callbacks)
  else:
    # Single-stage training; mode is one of _SINGLE_STAGE_MODES here.
    contrack_model = _build_compiled_model(mode, config.learning_rate)
    _fit(contrack_model, train_data, eval_data, config, callbacks)

  # Save the final model together with its config and encodings.
  filepath = FLAGS.model_path
  with tf.keras.utils.custom_object_scope(model.get_custom_objects()):
    tf.keras.models.save_model(contrack_model, filepath)
    environment.config.save(filepath)
    environment.encodings.save(filepath)
137

138

139
# Script entry point: absl parses the command-line flags and then calls train.
if __name__ == '__main__':
  app.run(main=train)
141

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.