# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for fine-tuning FlaxLM models."""

import functools

from flax import jax_utils as flax_jax_utils
from flax.training import common_utils
import jax
import tensorflow.compat.v1 as tf
import tqdm
from protein_lm import models
from protein_lm import utils

_SHUFFLE_BUFFER_SIZE = 5000


def _get_dataset(sequences, example_weights, batch_size, shuffle):
  """Builds a repeated, batched tf.data.Dataset over the input sequences."""
  data_dict = dict(sequence=sequences)
  if example_weights is not None:
    data_dict['example_weight'] = example_weights
  dataset = tf.data.Dataset.from_tensor_slices(data_dict)
  if shuffle:
    dataset = dataset.shuffle(_SHUFFLE_BUFFER_SIZE)
  dataset = dataset.repeat().batch(batch_size)
  return dataset
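
# A minimal sketch of how the helper above behaves, kept as a comment. The toy
# int-encoded sequences are hypothetical; real callers pass equal-length,
# int-encoded protein sequences:
#
#   dataset = _get_dataset(
#       sequences=[[1, 4, 5, 2], [1, 7, 7, 2]],
#       example_weights=None,
#       batch_size=2,
#       shuffle=False)
#   batch = next(iter(dataset))
#   # batch['sequence'] is an int tensor of shape [batch_size, sequence_length].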


class _OptimizationRunner(object):
  """Helper class for running optimization steps."""

  def __init__(self, model, learning_rate, **optimizer_kwargs):
    self._bos_token = model.bos_token
    self._pad_token = model.pad_token
    unreplicated_optimizer = model.get_weights()
    self._replicated_optimizer = utils.create_adam_optimizer(
        model=unreplicated_optimizer.target,
        learning_rate=learning_rate,
        **optimizer_kwargs)
    self._dropout_rngs = model._dropout_rngs

    # Compile a pmapped train step that is replicated across local devices.
    self._p_train_step = jax.pmap(
        functools.partial(
            models.train_step,
            preprocess_fn=model.preprocess,
            learning_rate_fn=lambda t: learning_rate),
        axis_name='batch')

  def fit_batch(self, batch, example_weights=None):
    """Runs one optimization step on batch."""
    # Split the batch across local devices to match the pmapped train step.
    batch = common_utils.shard(batch)

    if example_weights is not None:
      example_weights = common_utils.shard(example_weights)
    (self._replicated_optimizer, metrics,
     self._dropout_rngs) = self._p_train_step(
         self._replicated_optimizer,
         inputs=batch,
         example_weights=example_weights,
         dropout_rng=self._dropout_rngs)

    return metrics

  def get_weights(self):
    return flax_jax_utils.unreplicate(self._replicated_optimizer)


def fine_tune(model,
              initial_weights,
              sequences,
              batch_size,
              num_epochs,
              learning_rate,
              example_weights=None,
              shuffle=True,
              progress_bar=True,
              **optimizer_kwargs):
  """Fine-tunes the model on the given sequences.

  Args:
    model: A models.FlaxLM.
    initial_weights: The model is initialized with these weights.
    sequences: A list of int-encoded sequences to train on.
    batch_size: The batch size used when optimizing the model.
    num_epochs: Number of passes to take through the input sequences.
    learning_rate: Learning rate for optimization.
    example_weights: Optional per-sequence weights for performing weighted MLE
      training.
    shuffle: Whether the input sequences should be shuffled.
    progress_bar: Whether to display a progress bar.
    **optimizer_kwargs: Additional kwargs to pass to
      utils.create_adam_optimizer().

  Returns:
    A set of fine-tuned weights. The model can be set to use these via
      model.set_weights(fine_tuned_weights).
  """
  model.set_weights(initial_weights)

  runner = _OptimizationRunner(
      model, learning_rate=learning_rate, **optimizer_kwargs)

  dataset = _get_dataset(sequences, example_weights, batch_size, shuffle)
  dataset_iter = iter(dataset.repeat())

  num_iter = int(num_epochs * len(sequences) / batch_size)
  iterations = list(range(num_iter))
  if progress_bar:
    iterations = tqdm.tqdm(iterations, position=0)
  for _ in iterations:
    batch = next(dataset_iter)
    if example_weights is not None:
      batch_example_weights = batch['example_weight'].numpy()
    else:
      batch_example_weights = None
    batch_sequences = batch['sequence'].numpy()
    runner.fit_batch(batch_sequences, batch_example_weights)

  return runner.get_weights()
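

# A minimal usage sketch, kept as a comment so importing this module has no
# side effects. The FlaxLM construction and the toy hyperparameters below are
# placeholders, not an exact API:
#
#   model = models.FlaxLM(...)            # hypothetical constructor arguments
#   fine_tuned_weights = fine_tune(
#       model=model,
#       initial_weights=model.get_weights(),
#       sequences=int_encoded_sequences,   # list of int-encoded sequences
#       batch_size=8,
#       num_epochs=10,
#       learning_rate=1e-4)
#   model.set_weights(fine_tuned_weights)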
