# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base training classes."""

import abc
import functools

import flax
import jax
import optax


def l2_normalize(
    x,
    axis=-1,
    epsilon=1e-12,
):
  """L2 normalize a tensor on an axis with numerical stability."""
  norm = jax.numpy.linalg.norm(x, ord=2, axis=axis, keepdims=True)
  return x / jax.numpy.maximum(norm, epsilon)
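
# Illustrative usage of l2_normalize (assumed shapes; not part of this
# module's API surface):
#
#   import jax.numpy as jnp
#   x = jnp.arange(32.).reshape(4, 8)
#   x_unit = l2_normalize(x, axis=-1)
#   # Each row of x_unit now has (numerically stable) unit L2 norm:
#   jnp.allclose(jnp.linalg.norm(x_unit, axis=-1), 1.0)  # -> True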


class TrainingAlgo(abc.ABC):
  """Training Algorithm.

  Attributes:
    logdir: Location of the log directory.
    dataset: tf dataset to train on.
    batch_size: Batch size for training.
    model: The flax model to train.
    eval_model: The eval/inference version of the flax model.
    learning_rate: The learning rate for training.
    epochs: Number of epochs to train for.
    params: Optional params to start training from. If None, random params
      are initialized.
    state: Optional state to start training from.
    writer: Writer for writing to tensorboard.
    optimizer: Optimizer for training using gradients.
    optimizer_state: State of the optimizer.
    weight_decay: Weight decay coefficient.
    weight_decay_mask: Mask to use for weight decay. False values mean the
      parameter should be excluded from weight decay. This mask is applied in
      addition to a mask on batch norm parameters and bias parameters.
    rngs: The PRNGs for model applies.
  """

  # pylint: disable=unused-argument
  def __init__(self,
               logdir,
               dataset,
               batch_size,
               model,
               eval_model,
               learning_rate,
               epochs,
               params=None,
               state=None,
               writer=None,
               weight_decay=0.,
               weight_decay_mask=None,
               rngs=None,
               **kwargs):
    self.logdir = logdir
    self.dataset = dataset
    self.batch_size = batch_size
    self.model = model
    self.eval_model = eval_model
    self.epochs = epochs
    self.params = params
    self.state = state
    self.learning_rate = learning_rate
    self.writer = writer

    self.rngs = {'dropout': jax.random.PRNGKey(0)}
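
    # Weight decay should never touch batch-norm or bias parameters: build
    # boolean masks that are True only for leaves outside those two groups,
    # then AND them with any user-provided mask before passing to adamw below.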
    batch_norm_mask = jax.tree_map(
        lambda x: not x,
        self.generate_parameter_ancestors(self.params, 'batch_norm'))
    bias_mask = jax.tree_map(
        lambda x: not x, self.is_leaf_name(self.params, 'bias'))
    bias_and_bn_mask = jax.tree_map(lambda x, y: x and y, bias_mask,
                                    batch_norm_mask)

    if weight_decay_mask is None:
      weight_decay_mask = bias_and_bn_mask
    else:
      weight_decay_mask = jax.tree_map(lambda x, y: x and y,
                                       weight_decay_mask, bias_and_bn_mask)

    optimizer = optax.adamw(
        learning_rate=self.learning_rate,
        weight_decay=weight_decay,
        mask=weight_decay_mask,
    )
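    # optax.zero_nans() zeroes out any NaN entries in the final updates, so a
    # single bad batch cannot poison the parameters.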
    self.optimizer = optax.chain(optimizer, optax.zero_nans())
  # pylint: enable=unused-argument

  @abc.abstractmethod
  def _loss(self, *args, **kwargs):
    """Loss function that calls the model using params.

    Should return the scalar loss as the first value and a tuple of other
    auxiliary values, such as the updated model state.

    Args:
      *args: Positional arguments.
      **kwargs: Keyword arguments.
    """

  @functools.partial(jax.jit, static_argnums=(0,))
  def loss(self, *args, **kwargs):
    """Jitted version of the private loss."""
    return self._loss(*args, **kwargs)

  @functools.partial(jax.jit, static_argnums=(0,))
  def update_model(self, params, gradients, optimizer_state):
    """Applies the optimizer update computed from gradients to params."""
    updates, optimizer_state = self.optimizer.update(gradients,
                                                     optimizer_state,
                                                     params=params)
    params = optax.apply_updates(params, updates)

    return params, optimizer_state

  def get_grad_fn(self,):
    """Returns a jitted gradient function for the loss (with auxiliaries)."""
    return jax.jit(jax.grad(self.loss, has_aux=True))

  @abc.abstractmethod
  def run(self,):
    """Runs a training algorithm through a dataset for a fixed number of epochs.

    Returns:
      params: Parameters after training.
      state: Model state after training.
    """

  def update_rngs(self,):
    """Updates the rngs with new values."""
    new_rngs = {}
    for k, rng in self.rngs.items():
      rng, _ = jax.random.split(rng, 2)
      new_rngs[k] = rng
    self.rngs = new_rngs

  def generate_parameter_ancestors(self, params, name):
    """Returns a pytree indicating whether each leaf has an ancestor with name.

    Has the same structure as params, except each leaf is a boolean value
    where True indicates the parameter has an ancestor whose name contains
    name. This is useful for identifying parameters that should be excluded
    from weight decay.

    Args:
      params: A FrozenDict of parameter values.
      name: The name to match.
    """
    flattened = flax.traverse_util.flatten_dict(params.unfreeze())
    flattened_mask = {k: any(name in pname for pname in k)
                      for k in flattened}
    mask = flax.core.FrozenDict(
        flax.traverse_util.unflatten_dict(flattened_mask))
    return mask

  def is_leaf_name(self, params, name):
    """Returns a pytree indicating whether each leaf is named name.

    Has the same structure as params, except each leaf is a boolean value
    where True indicates the parameter is named name. This is useful for
    identifying parameters that should be excluded from weight decay.

    Args:
      params: A FrozenDict of parameter values.
      name: The name to match.
    """
    flattened = flax.traverse_util.flatten_dict(params.unfreeze())
    flattened_mask = {k: k[-1] == name for k in flattened}
    mask = flax.core.FrozenDict(
        flax.traverse_util.unflatten_dict(flattened_mask))
    return mask
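
  # Illustrative behaviour of the two mask helpers on a tiny, hypothetical
  # parameter tree (not taken from this codebase):
  #
  #   params = flax.core.FrozenDict(
  #       {'batch_norm_0': {'scale': 1.0}, 'dense_0': {'bias': 0.0}})
  #   self.generate_parameter_ancestors(params, 'batch_norm')
  #   # -> {'batch_norm_0': {'scale': True}, 'dense_0': {'bias': False}}
  #   self.is_leaf_name(params, 'bias')
  #   # -> {'batch_norm_0': {'scale': False}, 'dense_0': {'bias': True}}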


class PretextTrainingAlgo(TrainingAlgo):
  """Pretext Training Algo.

  Takes care of generating the weight decay masks for pretext parameters.
  """

  def __init__(self,
               logdir,
               dataset,
               batch_size,
               model,
               eval_model,
               learning_rate,
               epochs,
               params=None,
               state=None,
               writer=None,
               weight_decay=0.,
               weight_decay_mask=None,
               patience=32,):
    # Only apply weight decay to pretext parameters.
    pretext_mask = self.generate_parameter_ancestors(params, 'pretext')
    super(PretextTrainingAlgo, self).__init__(
        logdir,
        dataset,
        batch_size,
        model,
        eval_model,
        learning_rate,
        epochs,
        params=params,
        state=state,
        writer=writer,
        weight_decay=weight_decay,
        weight_decay_mask=pretext_mask,
    )
    self.patience = patience
    self.early_stop_params = self.params
    self.early_stop_state = self.state
    self.best_early_stop_loss = float('inf')
    self.patience_counter = 0
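
  # Sketch of how the early-stopping attributes above are typically consumed
  # inside a subclass's run() loop. The update rule below is an assumption for
  # illustration (with a hypothetical per-epoch `val_loss`), not code from
  # this file:
  #
  #   if val_loss < self.best_early_stop_loss:
  #     self.best_early_stop_loss = val_loss
  #     self.early_stop_params = self.params
  #     self.early_stop_state = self.state
  #     self.patience_counter = 0
  #   else:
  #     self.patience_counter += 1
  #     if self.patience_counter > self.patience:
  #       return self.early_stop_params, self.early_stop_state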