# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""SimSiam pretext.
17
18https://arxiv.org/pdf/2011.10566.pdf
19"""
20
21from absl import logging
22from flax.core import freeze
23import jax
24import jax.numpy as jnp
25
26from q_match.algorithms.training_algo import l2_normalize
27from q_match.algorithms.training_algo import PretextTrainingAlgo
28from q_match.algorithms.vime_pretext_training import vime_corruption
29
30
@jax.jit
def negative_cosine_similarity(x, y):
  """Mean negative cosine similarity over the last axis."""
  return -(x * y).sum(axis=-1).mean()

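# Worked example (hypothetical values; assumes l2_normalize normalizes along
# the last axis, as it is used throughout this file): identical unit vectors
# have cosine similarity 1, so the term evaluates to -1.
#
#   x = l2_normalize(jnp.ones((2, 4)))
#   negative_cosine_similarity(x, x)  # -> -1.0
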

@jax.jit
def simsiam_loss(proj_1, proj_2, pred_1, pred_2):
  """SimSiam loss.

  Args:
    proj_1: view 1 projection.
    proj_2: view 2 projection.
    pred_1: view 1 prediction.
    pred_2: view 2 prediction.

  Returns:
    Loss value.
  """
  proj_1 = jax.lax.stop_gradient(proj_1)
  proj_2 = jax.lax.stop_gradient(proj_2)

  proj_1 = l2_normalize(proj_1)
  proj_2 = l2_normalize(proj_2)
  pred_1 = l2_normalize(pred_1)
  pred_2 = l2_normalize(pred_2)

  loss_a = negative_cosine_similarity(proj_1, pred_2)
  loss_b = negative_cosine_similarity(proj_2, pred_1)

  return jnp.mean(loss_a + loss_b) / 2.

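# Minimal sanity check (hypothetical shapes): when both views collapse to the
# same representation, each term is -1 and the symmetric loss is at its
# minimum.
#
#   z = l2_normalize(jnp.ones((4, 8)))
#   simsiam_loss(proj_1=z, proj_2=z, pred_1=z, pred_2=z)  # -> -1.0
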

class SimSiamPretextTraining(PretextTrainingAlgo):
  """SimSiam Training Algorithm.

  Attributes:
    logdir: location of the log directory.
    dataset: tf dataset to train on.
    batch_size: batch size for training.
    model: model to train; its pretext outputs must include 'siam_proj'
      and 'siam_pred'.
    eval_model: model used for evaluation.
    learning_rate: the learning rate for training.
    epochs: number of epochs to train for.
    params: Optional params to start training from. If None, random params
      are initialized.
    state: Optional state to start training from.
    writer: Writer for writing to tensorboard.
    weight_decay: weight decay on pretext params.
    corruption_p: the probability of corrupting each feature for view 1.
    patience: number of epochs without validation improvement before
      early stopping.
  """

  def __init__(
      self,
      logdir,
      dataset,
      batch_size,
      model,
      eval_model,
      learning_rate,
      epochs,
      params=None,
      state=None,
      writer=None,
      weight_decay=0.,
      corruption_p=.3,
      patience=32,
      **kwargs
  ):

    super(SimSiamPretextTraining,
          self).__init__(logdir, dataset, batch_size, model, eval_model,
                         learning_rate, epochs, params, state, writer,
                         weight_decay, patience=patience)

    self.mask_key = jax.random.PRNGKey(99)
    self.corruption_p = corruption_p

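  # Note: self.mask_key seeds the feature corruption; run() re-splits it once
  # per training step, so each step draws a fresh corruption mask.
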
  def _loss(self, params, state,
            features, mask_key,
            ):
    """Loss with SimSiam."""

    variables = freeze({'params': params, **state})

    ## View 1
    view_1_features, _ = vime_corruption(features, self.corruption_p,
                                         mask_key)
    ## View 2
    # The first key from this split is reserved for the next step, so take
    # the second for the second view.
    _, new_mask_key = jax.random.split(mask_key)
    view_2_features, _ = vime_corruption(
        features, p=self.corruption_p, mask_key=new_mask_key)

    # View 1 Projection and Predictor
    output_1, updated_state = self.model.apply(variables,
                                               view_1_features,
                                               mutable=['batch_stats'],
                                               rngs=self.rngs)

    proj_1 = output_1['pretext']['siam_proj']
    pred_1 = output_1['pretext']['siam_pred']

    # View 2 Projection and Predictor
    output_2, updated_state = self.model.apply(variables,
                                               view_2_features,
                                               mutable=['batch_stats'],
                                               rngs=self.rngs)
    proj_2 = output_2['pretext']['siam_proj']
    pred_2 = output_2['pretext']['siam_pred']

    pretext_loss = simsiam_loss(proj_1=proj_1, proj_2=proj_2,
                                pred_1=pred_1, pred_2=pred_2)

    return pretext_loss, updated_state

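  # Model output tree assumed by _loss above (shapes are illustrative):
  #
  #   {'pretext': {'siam_proj': <projection, [batch, dim]>,
  #                'siam_pred': <prediction, [batch, dim]>}}
  #
  # Gradients flow only through the predictor branch, since both projections
  # are wrapped in jax.lax.stop_gradient inside simsiam_loss.
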
  def run(self):
    """Runs the pretext training algorithm."""
    params = self.params
    state = self.state
    dataset = self.dataset
    model = self.model

    example_data = jax.numpy.array(dataset.get_example_features())
    variables = freeze({'params': params, **state})
    example_output, _ = model.apply(variables,
                                    example_data,
                                    mutable=['batch_stats'],
                                    rngs=self.rngs,
                                    )
    logging.debug(str(example_output))

    optimizer_state = self.optimizer.init(params=params)

    grad_fn = self.get_grad_fn()

    steps = 0
    for epoch in range(self.epochs):
      logging.info('Pretext Epoch: %d', epoch)
      for example in dataset.get_pretext_ds():
        features = jax.numpy.array(example['features'])

        if steps % 100 == 0:
          pretext_loss, _ = self.loss(
              params, state, features,
              self.mask_key,
          )
          log_train_loss_msg = f'pretext training loss {pretext_loss}'
          logging.info(log_train_loss_msg)

          metrics = {'pretext_train_loss': pretext_loss}

          if self.writer is not None:
            self.writer.write_scalars(steps, metrics)

        gradients, state = grad_fn(params, state,
                                   features, self.mask_key,
                                   )

        params, optimizer_state = self.update_model(params,
                                                    gradients,
                                                    optimizer_state)
        self.update_rngs()
        self.mask_key, _ = jax.random.split(self.mask_key)
        steps += 1

      # Check the validation pretext dataset if it exists.
      pretext_validation_ds = dataset.get_pretext_validation_ds()
      if pretext_validation_ds is not None:
        # Compute the validation loss.
        validation_loss = 0.
        val_seen = 0
        val_mask_key = self.mask_key
        for example in pretext_validation_ds:
          features = jax.numpy.array(example['features'])
          seen = features.shape[0]
          validation_loss += self.loss(
              params,
              state,
              features,
              val_mask_key,
          )[0] * seen
          val_seen += seen
          val_mask_key, _ = jax.random.split(val_mask_key)
        validation_loss /= float(val_seen)

        if self.writer is not None:
          self.writer.write_scalars(
              epoch,
              {'pretext_validation_loss': validation_loss})
        if validation_loss < self.best_early_stop_loss:
          self.best_early_stop_loss = validation_loss
          self.early_stop_params = params
          self.early_stop_state = state
          self.patience_counter = 0
        else:
          self.patience_counter += 1

        if self.patience_counter > self.patience:
          break
      else:
        self.early_stop_params = params
        self.early_stop_state = state

    return self.early_stop_params, self.early_stop_state
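

# Usage sketch (hypothetical wiring; the dataset and model constructors live
# elsewhere in q_match and are not shown here, and the hyperparameter values
# are placeholders):
#
#   trainer = SimSiamPretextTraining(
#       logdir='/tmp/simsiam',
#       dataset=dataset,
#       batch_size=256,
#       model=pretext_model,
#       eval_model=eval_model,
#       learning_rate=1e-3,
#       epochs=100,
#   )
#   pretext_params, pretext_state = trainer.run()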