google-research

Форк
0
217 строк · 7.1 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""Train a transformer on dataset of sequences."""
17

18
import contextlib
19
import os
20
import time
21

22
from absl import app
23
from absl import flags
24
from absl import logging
25
import gin
26
import jax
27
import jax.nn
28
import jax.numpy as jnp
29
import tensorflow.compat.v1 as tf
30

31
from protein_lm import data
32
from protein_lm import evaluation
33
from protein_lm import logging as logging_lib
34
from protein_lm import models
35

36
# Summary API re-exported by protein_lm.logging; used below to create the
# metrics file writer and to flush pending summaries.
tf_summary = logging_lib.tf_summary

FLAGS = flags.FLAGS

# Where checkpoints, metrics and gin configs for the run are written.
flags.DEFINE_string('work_dir', None, 'Directory to store model data.')

# Gin configuration sources; both are consumed by main().
flags.DEFINE_multi_string('gin_files', [], 'List of paths to the config files.')
flags.DEFINE_multi_string(
    'gin_bindings', [], 'Newline separated list of Gin parameter bindings.')
47

48

49
def _write_gin_configs(output_file):
  """Logs the operative gin config and saves a copy of it to `output_file`.

  Args:
    output_file: Destination path; written with `tf.gfile.GFile` in text mode,
      replacing any existing contents.
  """
  operative = gin.operative_config_str()
  separator = '=' * 80
  logging.info(separator)
  logging.info('Gin configs\n%s', operative)
  logging.info(separator)
  with tf.gfile.GFile(output_file, 'w') as out:
    out.write(operative)
57

58

59
@gin.configurable('experiment')
def run_experiment(
    model_dir,
    data_dir=None,
    xid=None,
    batch_size_per_device=128,
    eval_frequency=500,
    checkpoint_frequency=10000,
    save_checkpoints=True,
    restore_checkpoint=True,
    num_eval_steps=None,
    epochs=None,
    max_train_steps=1000000,  # 1 million
    max_train_length=512,
    train_summary_frequency=100,
    max_eval_length=None,
    model_cls=models.FlaxLM):
  """Run experiment.

  Builds the train/eval datasets, trains `model_cls` on them and periodically
  writes train/eval scalar summaries, gin configs and checkpoints under
  `model_dir`. Only host 0 writes summaries, configs and checkpoints.

  Args:
    model_dir: Directory to save checkpoints and metrics to. If `xid` is given,
      the subdirectory `<xid>_l<max_train_length>` is used instead.
    data_dir: Directory to load data.
    xid: Optional experiment id.
    batch_size_per_device: Batch size per device; the global batch size is
      this times `jax.local_device_count()`.
    eval_frequency: Steps per eval. A falsy value disables evaluation.
    checkpoint_frequency: How often to checkpoint. If None, only checkpoint once
      at end of run.
    save_checkpoints: If True, checkpoints model according to
      checkpoint_frequency. The end-of-run checkpoint is written regardless.
    restore_checkpoint: If True, will restore checkpoint from directory. Useful
      for robustness to preemption.
    num_eval_steps: Number of eval steps to take on eval dataset.
    epochs: Number of train epochs; passed to `train_ds.repeat()` (for a
      tf.data dataset, None repeats indefinitely).
    max_train_steps: Stop training after N steps.
    max_train_length: Crop training sequences to this length.
    train_summary_frequency: Frequency to write train metrics. A falsy value
      disables train summaries.
    max_eval_length: Maximum eval length. Defaults to max_train_length.
    model_cls: Model class to use.

  Returns:
    FlaxLM resulting from running training.

  Raises:
    ValueError: If a combined train metric is NaN when train summaries are
      written on host 0.
  """
  if xid is not None:
    model_dir = os.path.join(model_dir, '%s_l%s' % (str(xid), max_train_length))
  tf.enable_v2_behavior()
  # Only the first host writes summaries, gin configs and checkpoints.
  is_chief = jax.host_id() == 0
  if is_chief:
    summary_writer = tf_summary.create_file_writer(
        os.path.join(model_dir, 'metrics'), max_queue=1, flush_millis=1000)
    train_summary_writer = logging_lib.ScalarSummary(
        step=None, scope='train/', enable_tf=True, verbose=0)
    eval_summary_writer = logging_lib.ScalarSummary(
        step=None, scope='eval/', enable_tf=True, verbose=0)

  batch_size = batch_size_per_device * jax.local_device_count()
  max_eval_length = max_eval_length or max_train_length
  train_files, test_files = data.get_train_valid_files(directory=data_dir)
  train_ds, eval_ds = data.load_dataset(
      train_files=train_files,
      test_files=test_files,
      batch_size=batch_size,
      max_train_length=max_train_length,
      max_eval_length=max_eval_length,
      shuffle_buffer=16384)

  with contextlib.ExitStack() as stack:
    if is_chief:
      # Only need metric writer context manager on host 0.
      stack.enter_context(summary_writer.as_default())
    model = model_cls(domain=data.protein_domain, batch_size=batch_size)

    if restore_checkpoint:
      try:
        model.load_checkpoint(model_dir)
      except ValueError as err:
        # load_checkpoint raises ValueError when there is nothing to load;
        # train from scratch, but leave a trace instead of failing silently.
        logging.info('No checkpoint restored from %s (%s); training from '
                     'scratch.', model_dir, err)
    start_step = model.train_step

    train_iter = iter(train_ds.repeat(epochs))
    train_metrics = []
    # Last step trained in this run, and last step whose checkpoint was
    # written. Used after the loop to save a final checkpoint when the data
    # runs out before `max_train_steps`.
    last_step = None
    last_saved_step = None

    if is_chief:
      _write_gin_configs(os.path.join(model_dir, 'config.gin'))

    # Start timing after the config write so it does not skew steps/sec.
    tick = time.time()
    num_evals = 0
    # zip() draws from the range first, so no extra batch is pulled once
    # max_train_steps is reached.
    for step, batch in zip(range(start_step, max_train_steps), train_iter):
      # NOTE(review): `_numpy()` is TF's private accessor, presumably used over
      # `.numpy()` to avoid a copy -- confirm.
      batch = jax.tree_util.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
      metrics = model.fit_batch(batch)
      train_metrics.append(metrics)
      last_step = step

      periodic_checkpoint = (
          save_checkpoints and checkpoint_frequency and
          step % checkpoint_frequency == 0 and step > 0)
      if is_chief and (periodic_checkpoint or step == max_train_steps - 1):
        model.save_checkpoint(model_dir)
        last_saved_step = step

      if train_summary_frequency and (step + 1) % train_summary_frequency == 0:
        summary = evaluation.combine_metrics(train_metrics)
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
        if is_chief:
          tock = time.time()
          # train_metrics holds one entry per step since the last reset, so
          # this is the true rate (also for a shorter first window after
          # restoring mid-window).
          steps_per_sec = len(train_metrics) / (tock - tick)
          tick = tock
          train_summary_writer('steps per second', steps_per_sec, step)
          for key, val in summary.items():
            if jnp.isnan(val):
              raise ValueError(f'NaN in {key} at step {step}.')
            train_summary_writer(key, val, step)

        # Reset metric accumulation for the next train summary window.
        train_metrics = []

      if eval_frequency and (step + 1) % eval_frequency == 0:
        eval_summary = evaluation.evaluate(
            model=model, eval_ds=eval_ds, num_eval_steps=num_eval_steps)

        logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
        if is_chief:
          for key, val in eval_summary.items():
            eval_summary_writer(key, val, step)
          tf_summary.flush()
          summary_writer.flush()

          if num_evals == 0:
            # Write out config on first eval.
            _write_gin_configs(os.path.join(model_dir, 'config_after_eval.gin'))
          num_evals += 1

    # If the training data ran out before `max_train_steps`, the step-based
    # end-of-run checkpoint above never fired; save the last trained step so
    # the tail of the run is not lost.
    if is_chief and last_step is not None and last_step != last_saved_step:
      model.save_checkpoint(model_dir)

  if is_chief:
    tf_summary.flush()
    summary_writer.close()
    _write_gin_configs(os.path.join(model_dir, 'config_end.gin'))
  return model
199

200

201
def main(argv):
  """Entry point: applies gin configuration from flags and runs training.

  Args:
    argv: Command-line arguments handed over by `app.run`; only the program
      name is accepted.

  Raises:
    app.UsageError: If any positional arguments beyond the program name are
      given.
  """
  extra_args = argv[1:]
  if extra_args:
    raise app.UsageError('Too many command-line arguments.')
  logging.info('Main called')

  config_files, config_bindings = FLAGS.gin_files, FLAGS.gin_bindings

  # Parse gin configs.
  logging.info('Gin files: %s', str(config_files))
  logging.info('Gin bindings: %s', str(config_bindings))
  gin.parse_config_files_and_bindings(config_files, config_bindings)
  run_experiment(model_dir=FLAGS.work_dir)
214

215

216
if __name__ == '__main__':
  # Delegate to absl's app runner, which invokes main(argv).
  app.run(main)
218

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.