# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base training classes."""

import abc
import functools

import flax
import jax
import optax


def l2_normalize(
    x,
    axis=-1,
    epsilon=1e-12,
):
  """L2-normalizes a tensor along an axis with numerical stability."""
  norm = jax.numpy.linalg.norm(x, ord=2, axis=axis, keepdims=True)
  return x / jax.numpy.maximum(norm, epsilon)
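
# Added example (not part of the original module): with the default axis=-1,
# l2_normalize rescales each trailing-axis vector to unit L2 norm, e.g.
#
#   v = jax.numpy.array([[3.0, 4.0]])
#   l2_normalize(v)  # -> [[0.6, 0.8]], since ||(3.0, 4.0)||_2 = 5.0.
#
# Norms smaller than epsilon are clamped to epsilon, so an all-zero input is
# returned unchanged instead of triggering a division by zero.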


class TrainingAlgo(abc.ABC):
  """Training algorithm.

  Attributes:
    logdir: Location of the log directory.
    dataset: tf dataset to train on.
    batch_size: Batch size for training.
    model: The flax model to train.
    eval_model: The eval/inference version of the flax model.
    learning_rate: The learning rate for training.
    epochs: Number of epochs to train for.
    params: Optional params to start training from. If None, random params
      are initialized.
    state: Optional state to start training from.
    writer: Writer for writing to TensorBoard.
    optimizer: Optimizer for training using gradients.
    optimizer_state: State of the optimizer.
    weight_decay: Weight decay coefficient.
    weight_decay_mask: Mask to use for weight decay. False values mean the
      parameter should be excluded from weight decay. This mask is applied in
      addition to a mask on batch norm parameters and bias parameters.
    rngs: The PRNGs for model applies.
  """

  # pylint: disable=unused-argument
  def __init__(self,
               logdir,
               dataset,
               batch_size,
               model,
               eval_model,
               learning_rate,
               epochs,
               params=None,
               state=None,
               writer=None,
               weight_decay=0.,
               weight_decay_mask=None,
               rngs=None,
               **kwargs):
    self.logdir = logdir
    self.dataset = dataset
    self.batch_size = batch_size
    self.model = model
    self.eval_model = eval_model
    self.epochs = epochs
    self.params = params
    self.state = state
    self.learning_rate = learning_rate
    self.writer = writer

    self.rngs = {'dropout': jax.random.PRNGKey(0)}

    batch_norm_mask = jax.tree_map(
        lambda x: not x,
        self.generate_parameter_ancestors(self.params, 'batch_norm'))
    bias_mask = jax.tree_map(
        lambda x: not x, self.is_leaf_name(self.params, 'bias'))
    bias_and_bn_mask = jax.tree_map(lambda x, y: x and y, bias_mask,
                                    batch_norm_mask)

    if weight_decay_mask is None:
      weight_decay_mask = bias_and_bn_mask
    else:
      weight_decay_mask = jax.tree_map(lambda x, y: x and y,
                                       weight_decay_mask, bias_and_bn_mask)
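    # Added note (not in the original code): the combined mask is True exactly
    # for the parameters that should receive weight decay. For a hypothetical
    # parameter tree
    #   {'batch_norm': {'scale': ...}, 'dense': {'kernel': ..., 'bias': ...}}
    # the mask is
    #   {'batch_norm': {'scale': False},
    #    'dense': {'kernel': True, 'bias': False}},
    # so only dense/kernel is decayed by the adamw optimizer below.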

    optimizer = optax.adamw(
        learning_rate=self.learning_rate,
        weight_decay=weight_decay,
        mask=weight_decay_mask,
    )
    self.optimizer = optax.chain(optimizer, optax.zero_nans())
  # pylint: enable=unused-argument

  @abc.abstractmethod
  def _loss(self, *args, **kwargs):
    """Loss function that calls the model using params.

    Should return the scalar loss as the first value and a tuple of other
    auxiliary values, such as the updated model state.

    Args:
      *args: Positional arguments.
      **kwargs: Keyword arguments.
    """

  @functools.partial(jax.jit, static_argnums=(0,))
  def loss(self, *args, **kwargs):
    """Jitted version of the private loss."""
    return self._loss(*args, **kwargs)

  @functools.partial(jax.jit, static_argnums=(0,))
  def update_model(self, params, gradients, optimizer_state):
    """Applies the optimizer updates to the parameters."""
    updates, optimizer_state = self.optimizer.update(gradients,
                                                     optimizer_state,
                                                     params=params)
    params = optax.apply_updates(params, updates)

    return params, optimizer_state

  def get_grad_fn(self,):
    """Returns a jitted gradient function of the loss (with aux outputs)."""
    return jax.jit(jax.grad(self.loss, has_aux=True))

  @abc.abstractmethod
  def run(self,):
    """Runs a training algorithm through a dataset for a fixed number of epochs.

    Returns:
      params: Parameters after training.
      state: Model state after training.
    """

  def update_rngs(self,):
    """Updates the rngs with new values."""
    new_rngs = {}
    for k, rng in self.rngs.items():
      rng, _ = jax.random.split(rng, 2)
      new_rngs[k] = rng
    self.rngs = new_rngs

  def generate_parameter_ancestors(self, params, name):
    """Returns a pytree indicating whether each leaf has an ancestor with name.

    Has the same structure as params, except each leaf is a boolean value
    where True indicates the parameter has name as an ancestor. This is
    useful for identifying parameters that should be excluded from weight
    decay.

    Args:
      params: A FrozenDict of parameter values.
      name: The name to match.
    """
    flattened = flax.traverse_util.flatten_dict(params.unfreeze())
    flattened_mask = {k: any(name in pname for pname in k)
                      for k in flattened}
    mask = flax.core.FrozenDict(
        flax.traverse_util.unflatten_dict(flattened_mask))
    return mask
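  # Added example (not in the original code): for a hypothetical FrozenDict
  #   {'pretext': {'dense': {'kernel': ...}}, 'head': {'kernel': ...}}
  # generate_parameter_ancestors(params, 'pretext') returns
  #   {'pretext': {'dense': {'kernel': True}}, 'head': {'kernel': False}},
  # since the match is a substring test against each ancestor module name.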

  def is_leaf_name(self, params, name):
    """Returns a pytree indicating whether each leaf is named name.

    Has the same structure as params, except each leaf is a boolean value
    where True indicates the parameter is named name. This is useful for
    identifying parameters that should be excluded from weight decay.

    Args:
      params: A FrozenDict of parameter values.
      name: The name to match.
    """
    flattened = flax.traverse_util.flatten_dict(params.unfreeze())
    flattened_mask = {k: k[-1] == name for k in flattened}
    mask = flax.core.FrozenDict(
        flax.traverse_util.unflatten_dict(flattened_mask))
    return mask


class PretextTrainingAlgo(TrainingAlgo):
  """Pretext training algorithm.

  Takes care of generating the weight decay masks for pretext parameters.
  """

  def __init__(self,
               logdir,
               dataset,
               batch_size,
               model,
               eval_model,
               learning_rate,
               epochs,
               params=None,
               state=None,
               writer=None,
               weight_decay=0.,
               weight_decay_mask=None,
               patience=32,):
    # Only apply weight decay to pretext parameters.
    pretext_mask = self.generate_parameter_ancestors(params, 'pretext')
    super(PretextTrainingAlgo, self).__init__(
        logdir,
        dataset,
        batch_size,
        model,
        eval_model,
        learning_rate,
        epochs,
        params=params,
        state=state,
        writer=writer,
        weight_decay=weight_decay,
        weight_decay_mask=pretext_mask,
    )
    self.patience = patience
    self.early_stop_params = self.params
    self.early_stop_state = self.state
    self.best_early_stop_loss = float('inf')
    self.patience_counter = 0
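
# Illustrative sketch (added for exposition; not part of the original module):
# a minimal concrete TrainingAlgo subclass showing how the abstract pieces fit
# together. The class name, the _loss signature, and the assumption that
# self.dataset yields batches of numpy/JAX arrays are hypothetical.
class _ExampleTrainingAlgo(TrainingAlgo):
  """Minimal concrete subclass, for illustration only."""

  def _loss(self, params, state, features, rngs):
    # Assumes a flax model whose apply() takes a variable dict of params plus
    # mutable collections (e.g. batch_stats) and returns predictions together
    # with the updated collections.
    preds, new_state = self.model.apply(
        {'params': params, **state},
        features,
        rngs=rngs,
        mutable=list(state.keys()),
    )
    loss = jax.numpy.mean(jax.numpy.square(preds))  # Placeholder objective.
    return loss, (new_state,)

  def run(self,):
    grad_fn = self.get_grad_fn()
    optimizer_state = self.optimizer.init(self.params)
    for _ in range(self.epochs):
      for features in self.dataset:
        # grad_fn returns (gradients, aux) because has_aux=True.
        gradients, (self.state,) = grad_fn(
            self.params, self.state, features, self.rngs)
        self.params, optimizer_state = self.update_model(
            self.params, gradients, optimizer_state)
      self.update_rngs()
    return self.params, self.state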