google-research
217 строк · 7.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Train a transformer on a dataset of sequences."""
17
18import contextlib19import os20import time21
22from absl import app23from absl import flags24from absl import logging25import gin26import jax27import jax.nn28import jax.numpy as jnp29import tensorflow.compat.v1 as tf30
31from protein_lm import data32from protein_lm import evaluation33from protein_lm import logging as logging_lib34from protein_lm import models35
# Shorthand for the summary API exposed by the project's logging helpers; used
# below to create the metrics file writer and to flush pending summaries.
tf_summary = logging_lib.tf_summary

FLAGS = flags.FLAGS

# Command-line flags: output directory plus gin configuration sources.
flags.DEFINE_string(
    name='work_dir',
    default=None,
    help='Directory to store model data.')
flags.DEFINE_multi_string(
    name='gin_files',
    default=[],
    help='List of paths to the config files.')
flags.DEFINE_multi_string(
    name='gin_bindings',
    default=[],
    help='Newline separated list of Gin parameter bindings.')

48
def _write_gin_configs(output_file):
  """Logs the operative gin config and saves it to `output_file`.

  Args:
    output_file: Path of the file that receives the operative config string.
  """
  operative_config = gin.operative_config_str()
  separator = '=' * 80

  # Frame the config in the log so it is easy to spot.
  logging.info(separator)
  logging.info('Gin configs\n%s', operative_config)
  logging.info(separator)

  with tf.gfile.GFile(output_file, 'w') as config_file:
    config_file.write(operative_config)

58
@gin.configurable('experiment')
def run_experiment(
    model_dir,
    data_dir=None,
    xid=None,
    batch_size_per_device=128,
    eval_frequency=500,
    checkpoint_frequency=10000,
    save_checkpoints=True,
    restore_checkpoint=True,
    num_eval_steps=None,
    epochs=None,
    max_train_steps=1000000,  # 1 million
    max_train_length=512,
    train_summary_frequency=100,
    max_eval_length=None,
    model_cls=models.FlaxLM):
  """Trains a language model, periodically evaluating and checkpointing it.

  Summaries, checkpoints and gin config snapshots are only written by the
  process for which `jax.host_id()` is 0.

  Args:
    model_dir: Directory to save checkpoints and metrics to.
    data_dir: Directory to load data from.
    xid: Optional experiment id. When given, outputs go to the subdirectory
      `<xid>_l<max_train_length>` of `model_dir`.
    batch_size_per_device: Batch size per device; multiplied by
      `jax.local_device_count()` to obtain the batch size.
    eval_frequency: Steps per eval. A falsy value disables evaluation.
    checkpoint_frequency: How often (in steps) to checkpoint. If falsy, the
      only checkpoint is the one written at step `max_train_steps - 1`.
    save_checkpoints: If True, checkpoints the model according to
      `checkpoint_frequency`. The checkpoint at step `max_train_steps - 1` is
      written regardless of this setting.
    restore_checkpoint: If True, tries to restore a checkpoint from
      `model_dir` before training. Useful for robustness to preemption.
    num_eval_steps: Number of eval steps to take on the eval dataset.
    epochs: Number of train epochs; passed to `train_ds.repeat`.
    max_train_steps: Stop training after N steps.
    max_train_length: Crop training sequences to this length.
    train_summary_frequency: Number of steps between train metric summaries.
    max_eval_length: Maximum eval length. Defaults to `max_train_length`.
    model_cls: Model class to use.

  Returns:
    The model instance (of type `model_cls`) after training.

  Raises:
    ValueError: If a combined training metric is NaN (checked on host 0).
  """
  if xid is not None:
    model_dir = os.path.join(model_dir, '%s_l%s' % (str(xid), max_train_length))
  tf.enable_v2_behavior()

  # Evaluated once and reused for every host-0-only action below.
  is_lead_host = jax.host_id() == 0

  if is_lead_host:
    summary_writer = tf_summary.create_file_writer(
        os.path.join(model_dir, 'metrics'), max_queue=1, flush_millis=1000)
    train_summary_writer = logging_lib.ScalarSummary(
        step=None,
        scope='train/',
        enable_tf=True,
        verbose=0)
    eval_summary_writer = logging_lib.ScalarSummary(
        step=None,
        scope='eval/',
        enable_tf=True,
        verbose=0)

  batch_size = batch_size_per_device * jax.local_device_count()
  max_eval_length = max_eval_length or max_train_length
  train_files, test_files = data.get_train_valid_files(directory=data_dir)
  train_ds, eval_ds = data.load_dataset(
      train_files=train_files,
      test_files=test_files,
      batch_size=batch_size,
      max_train_length=max_train_length,
      max_eval_length=max_eval_length,
      shuffle_buffer=16384)

  with contextlib.ExitStack() as stack:  # pylint: disable=using-constant-test
    if is_lead_host:
      # Only need metric writer context manager on host 0.
      stack.enter_context(summary_writer.as_default())
    model = model_cls(domain=data.protein_domain, batch_size=batch_size)

    if restore_checkpoint:
      try:
        model.load_checkpoint(model_dir)
      except ValueError as e:
        # No checkpoint to load -> raises ValueError. Training then starts
        # from the freshly constructed model; log it so it is not silent.
        logging.info('No checkpoint restored from %s (%s).', model_dir, e)
    start_step = model.train_step

    train_ds = train_ds.repeat(epochs)
    train_iter = iter(train_ds)
    train_metrics = []
    # Wall-clock time and step count at the last throughput measurement.
    tick = time.time()
    steps_at_tick = start_step

    if is_lead_host:
      _write_gin_configs(os.path.join(model_dir, 'config.gin'))

    num_evals = 0
    # zip() stops at whichever ends first: the step budget or the dataset.
    for step, batch in zip(range(start_step, max_train_steps), train_iter):
      batch = jax.tree_util.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
      metrics = model.fit_batch(batch)
      train_metrics.append(metrics)

      periodic_checkpoint = (
          save_checkpoints and checkpoint_frequency and
          step % checkpoint_frequency == 0 and step > 0)
      final_step = step == max_train_steps - 1
      if is_lead_host and (periodic_checkpoint or final_step):
        model.save_checkpoint(model_dir)

      if (step + 1) % train_summary_frequency == 0:
        summary = evaluation.combine_metrics(train_metrics)
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
        if is_lead_host:
          tock = time.time()
          # Throughput over the steps actually taken since the last tick. The
          # timer is reset at every train summary, so the step count must come
          # from this interval and not from `eval_frequency`.
          steps_per_sec = (step + 1 - steps_at_tick) / (tock - tick)
          tick = tock
          steps_at_tick = step + 1
          train_summary_writer('steps per second', steps_per_sec, step)
          for key, val in summary.items():
            if jnp.isnan(val):
              raise ValueError(f'NaN in {key} at step {step}.')
            train_summary_writer(key, val, step)

        # Reset metric accumulation for the next summary cycle.
        train_metrics = []

      if eval_frequency and (step + 1) % eval_frequency == 0:
        eval_summary = evaluation.evaluate(
            model=model, eval_ds=eval_ds, num_eval_steps=num_eval_steps)

        logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
        if is_lead_host:
          for key, val in eval_summary.items():
            eval_summary_writer(key, val, step)
          tf_summary.flush()
          summary_writer.flush()

          if num_evals == 0:
            # Write out config on first eval.
            _write_gin_configs(os.path.join(model_dir, 'config_after_eval.gin'))
          num_evals += 1

    if is_lead_host:
      tf_summary.flush()
      summary_writer.close()
      _write_gin_configs(os.path.join(model_dir, 'config_end.gin'))
    return model

200
def main(argv):
  """Parses the gin configuration given by flags and runs the experiment.

  Args:
    argv: Positional command-line arguments; must hold at most one element.

  Raises:
    app.UsageError: If `argv` contains more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  logging.info('Main called')

  config_files = FLAGS.gin_files
  config_bindings = FLAGS.gin_bindings

  # Parse gin configs.
  logging.info('Gin files: %s', config_files)
  logging.info('Gin bindings: %s', config_bindings)
  gin.parse_config_files_and_bindings(config_files, config_bindings)

  run_experiment(model_dir=FLAGS.work_dir)


if __name__ == '__main__':
  app.run(main)
