# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for fine-tuning FlaxLM models."""

import functools

from flax import jax_utils as flax_jax_utils
from flax.training import common_utils
import jax
import tensorflow.compat.v1 as tf
import tqdm

from protein_lm import models
from protein_lm import utils

_SHUFFLE_BUFFER_SIZE = 5000


def _get_dataset(sequences, example_weights, batch_size, shuffle):
  """Builds a repeating, batched tf.data.Dataset over the input sequences.

  Each element is a dict containing a batch of 'sequence' values, plus a
  batch of 'example_weight' values when example_weights is provided.
  """
  data_dict = dict(sequence=sequences)
  if example_weights is not None:
    data_dict['example_weight'] = example_weights
  dataset = tf.data.Dataset.from_tensor_slices(data_dict)
  if shuffle:
    dataset = dataset.shuffle(_SHUFFLE_BUFFER_SIZE)
  dataset = dataset.repeat().batch(batch_size)
  return dataset


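# Note: the train step below is pmapped over the 'batch' axis, and
# _OptimizationRunner.fit_batch shards each batch across local devices with
# common_utils.shard, so the batch size must be divisible by
# jax.local_device_count().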
class _OptimizationRunner(object):
  """Helper class for running optimization steps."""

  def __init__(self, model, learning_rate, **optimizer_kwargs):
    self._bos_token = model.bos_token
    self._pad_token = model.pad_token
    unreplicated_optimizer = model.get_weights()
    self._replicated_optimizer = utils.create_adam_optimizer(
        model=unreplicated_optimizer.target,
        learning_rate=learning_rate,
        **optimizer_kwargs)
    self._dropout_rngs = model._dropout_rngs

    self._p_train_step = jax.pmap(
        functools.partial(
            models.train_step,
            preprocess_fn=model.preprocess,
            learning_rate_fn=lambda t: learning_rate),
        axis_name='batch')

  def fit_batch(self, batch, example_weights=None):
    """Runs one optimization step on batch."""
    batch = common_utils.shard(batch)

    if example_weights is not None:
      example_weights = common_utils.shard(example_weights)
    (self._replicated_optimizer, metrics,
     self._dropout_rngs) = self._p_train_step(
         self._replicated_optimizer,
         inputs=batch,
         example_weights=example_weights,
         dropout_rng=self._dropout_rngs)

    return metrics

  def get_weights(self):
    return flax_jax_utils.unreplicate(self._replicated_optimizer)


def fine_tune(model,
              initial_weights,
              sequences,
              batch_size,
              num_epochs,
              learning_rate,
              example_weights=None,
              shuffle=True,
              progress_bar=True,
              **optimizer_kwargs):
  """Fine-tunes model on sequences.

  Args:
    model: A models.FlaxLM.
    initial_weights: The model is initialized with these weights.
    sequences: A list of int-encoded sequences to train on.
    batch_size: The batch size used when optimizing the model.
    num_epochs: Number of passes to take through the input sequences.
    learning_rate: Learning rate for optimization.
    example_weights: Optional per-sequence weights for performing weighted MLE
      training.
    shuffle: Whether the input sequences should be shuffled.
    progress_bar: Whether to display a progress bar.
    **optimizer_kwargs: Additional kwargs to pass to
      utils.create_adam_optimizer().

  Returns:
    A set of fine-tuned weights. The model can be set to use these via
    model.set_weights(fine_tuned_weights).
  """
  model.set_weights(initial_weights)

  runner = _OptimizationRunner(
      model, learning_rate=learning_rate, **optimizer_kwargs)

  dataset = _get_dataset(sequences, example_weights, batch_size, shuffle)
  dataset_iter = iter(dataset.repeat())

  # Each epoch corresponds to len(sequences) / batch_size optimization steps.
  num_iter = int(num_epochs * len(sequences) / batch_size)
  iterations = list(range(num_iter))
  if progress_bar:
    iterations = tqdm.tqdm(iterations, position=0)
  for _ in iterations:
    batch = next(dataset_iter)
    batch_example_weights = (
        batch['example_weight'].numpy() if example_weights is not None
        else None)
    batch_sequences = batch['sequence'].numpy()
    runner.fit_batch(batch_sequences, batch_example_weights)

  return runner.get_weights()
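
# Example usage (a minimal sketch; `lm`, `pretrained_weights`, and
# `train_sequences` are hypothetical placeholders for a models.FlaxLM
# instance, its starting weights, and a list of int-encoded sequences):
#
#   fine_tuned_weights = fine_tune(
#       model=lm,
#       initial_weights=pretrained_weights,
#       sequences=train_sequences,
#       batch_size=8,
#       num_epochs=1,
#       learning_rate=1e-4)
#   lm.set_weights(fine_tuned_weights)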