google-research

Форк
0
/
simsiam_pretext_training.py 
240 строк · 7.6 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""SimSiam pretext training.

Implements "Exploring Simple Siamese Representation Learning" (Chen & He):
https://arxiv.org/pdf/2011.10566.pdf
"""
20

21
from absl import logging
22
from flax.core import freeze
23
import jax
24
import jax.numpy as jnp
25

26
from q_match.algorithms.training_algo import l2_normalize
27
from q_match.algorithms.training_algo import PretextTrainingAlgo
28
from q_match.algorithms.vime_pretext_training import vime_corruption
29

30

31
@jax.jit
def negative_cosine_similarity(x, y):
  """Negated mean of the row-wise dot products of `x` and `y`.

  This equals the negative cosine similarity only when the vectors along the
  last axis of `x` and `y` are already L2-normalized; no normalization is
  performed here.

  Args:
    x: array whose last axis holds the vectors being compared.
    y: array broadcast-compatible with `x`.

  Returns:
    Scalar: minus the mean over all leading positions of `sum(x * y, -1)`.
  """
  row_dots = jnp.sum(x * y, axis=-1)
  return -jnp.mean(row_dots)
34

35

36
@jax.jit
def simsiam_loss(proj_1, proj_2, pred_1, pred_2,):
  """Symmetrized SimSiam loss.

  Each view's predictor output is compared against the other view's
  projection. Projections are wrapped in `stop_gradient`, so no gradient flows
  through them. All four inputs are passed through `l2_normalize` before the
  comparison.

  Args:
    proj_1: projection of view 1 (treated as a constant target).
    proj_2: projection of view 2 (treated as a constant target).
    pred_1: predictor output of view 1.
    pred_2: predictor output of view 2.

  Returns:
    Scalar loss: the average of the two cross-view negative similarities.
  """
  target_1 = l2_normalize(jax.lax.stop_gradient(proj_1))
  target_2 = l2_normalize(jax.lax.stop_gradient(proj_2))
  online_1 = l2_normalize(pred_1)
  online_2 = l2_normalize(pred_2)

  # Cross-view pairing: view-2 prediction vs. view-1 target, and vice versa.
  cross_terms = (negative_cosine_similarity(target_1, online_2)
                 + negative_cosine_similarity(target_2, online_1))
  return jnp.mean(cross_terms) / 2
62

63

64
class SimSiamPretextTraining(PretextTrainingAlgo):
  """SimSiam Training Algorithm.

  Two corrupted views of each batch are produced with `vime_corruption`. Both
  views are passed through the same model, and `simsiam_loss` is minimized
  between each view's predictor output and the other view's (stop-gradient)
  projection.

  Attributes:
    logdir: location of the log directory.
    dataset: dataset to train on; must provide `get_example_features`,
      `get_pretext_ds` and `get_pretext_validation_ds`.
    batch_size: batch size for training.
    model: model whose output contains `output['pretext']['siam_proj']` and
      `output['pretext']['siam_pred']`.
    eval_model: evaluation model, forwarded to the base class.
    learning_rate: the learning rate for training.
    epochs: number of epochs to train for.
    params: Optional params to start training from.  If None, random params
      are initialized.
    state: Optional state to start training from.
    writer: Optional writer for writing scalars to tensorboard.
    weight_decay: weight decay on pretext params.
    corruption_p: probability of corrupting each feature, used for both views.
    patience: training stops once more than `patience` consecutive epochs
      pass without a validation-loss improvement.
    mask_key: PRNG key driving the corruption masks; advanced every step.
  """

  def __init__(
      self,
      logdir,
      dataset,
      batch_size,
      model,
      eval_model,
      learning_rate,
      epochs,
      params=None,
      state=None,
      writer=None,
      weight_decay=0.,
      corruption_p=.3,
      patience=32,
      **kwargs
  ):
    # Extra keyword arguments are accepted but not used by this class.
    super().__init__(logdir, dataset, batch_size, model, eval_model,
                     learning_rate, epochs, params, state, writer,
                     weight_decay, patience=patience)

    # Fixed seed for the corruption masks; `run` advances it once per step.
    self.mask_key = jax.random.PRNGKey(99)
    self.corruption_p = corruption_p

  def _loss(self, params, state, features, mask_key):
    """Computes the SimSiam loss for one batch.

    Args:
      params: model parameters.
      state: non-parameter model variables (e.g. `batch_stats`).
      features: batch of input features.
      mask_key: PRNG key for the corruption masks. View 1 uses it directly;
        view 2 uses the second half of `jax.random.split(mask_key)`.

    Returns:
      Tuple `(pretext_loss, updated_state)`, where `updated_state` holds the
      mutable collections returned by the view-2 forward pass.
    """
    variables = freeze({'params': params, **state})

    ## View 1
    view_1_features, _ = vime_corruption(features, self.corruption_p,
                                         mask_key)
    ## View 2
    # Derive view 2's key from the `mask_key` argument rather than from
    # `self.mask_key`, so both views depend only on the key actually passed
    # in. This matters for validation, which passes advancing keys, and if
    # this function is traced/jitted, where `self.mask_key` would otherwise
    # be captured as a trace-time constant. `run` advances its key with the
    # FIRST half of the split, so the second half is used here.
    _, view_2_key = jax.random.split(mask_key)
    view_2_features, _ = vime_corruption(
        features, p=self.corruption_p, mask_key=view_2_key)

    # View 1 projection and predictor.
    output_1, _ = self.model.apply(variables,
                                   view_1_features,
                                   mutable=['batch_stats'],
                                   rngs=self.rngs)
    proj_1 = output_1['pretext']['siam_proj']
    pred_1 = output_1['pretext']['siam_pred']

    # View 2 projection and predictor. Both passes start from the same
    # `variables`, so only view 2's mutable-state update is returned.
    output_2, updated_state = self.model.apply(variables,
                                               view_2_features,
                                               mutable=['batch_stats'],
                                               rngs=self.rngs)
    proj_2 = output_2['pretext']['siam_proj']
    pred_2 = output_2['pretext']['siam_pred']

    pretext_loss = simsiam_loss(proj_1=proj_1, proj_2=proj_2,
                                pred_1=pred_1, pred_2=pred_2)

    return pretext_loss, updated_state

  def _pretext_validation_loss(self, params, state, validation_ds):
    """Returns the example-weighted mean pretext loss over `validation_ds`.

    Corruption keys start at `self.mask_key` and are advanced per batch on a
    local copy; `self.mask_key` itself is not modified. Returns None if the
    dataset yields no examples.
    """
    total_loss = 0.
    total_seen = 0
    val_mask_key = self.mask_key
    for example in validation_ds:
      features = jnp.array(example['features'])
      seen = features.shape[0]
      total_loss += self.loss(params, state, features, val_mask_key)[0] * seen
      total_seen += seen
      val_mask_key, _ = jax.random.split(val_mask_key)
    if total_seen == 0:
      return None
    return total_loss / float(total_seen)

  def run(self,):
    """Runs a pretext training algo.

    Trains for `self.epochs` epochs. After each epoch, if a non-empty pretext
    validation set exists, its loss drives early stopping: the best
    params/state seen so far are kept, and training stops once
    `self.patience_counter` exceeds `self.patience`. Without a usable
    validation set, the latest params/state are kept.

    Returns:
      Tuple `(early_stop_params, early_stop_state)`.
    """
    params = self.params
    state = self.state
    dataset = self.dataset
    model = self.model

    # One forward pass on example data; the output is only logged at debug.
    example_data = jnp.array(dataset.get_example_features())
    variables = freeze({'params': params, **state})
    example_output, _ = model.apply(variables,
                                    example_data,
                                    mutable=['batch_stats'],
                                    rngs=self.rngs,
                                    )
    logging.debug(str(example_output))

    optimizer_state = self.optimizer.init(params=params)

    grad_fn = self.get_grad_fn()

    steps = 0
    for epoch in range(self.epochs):
      logging.info('Pretext Epoch: %d', epoch)
      for example in dataset.get_pretext_ds():
        features = jnp.array(example['features'])

        if steps % 100 == 0:
          pretext_loss, _ = self.loss(params, state, features, self.mask_key)
          log_train_loss_msg = f'pretext training loss {pretext_loss}'
          logging.info(log_train_loss_msg)

          metrics = {'pretext_train_loss': pretext_loss,}

          if self.writer is not None:
            self.writer.write_scalars(steps, metrics)

        gradients, state = grad_fn(params, state, features, self.mask_key)

        params, optimizer_state = self.update_model(params,
                                                    gradients,
                                                    optimizer_state)
        self.update_rngs()
        # Advance with the first half; `_loss` uses the second half for view 2.
        self.mask_key, _ = jax.random.split(self.mask_key)
        steps += 1

      # Check the pretext validation dataset, if it exists.
      validation_loss = None
      pretext_validation_ds = dataset.get_pretext_validation_ds()
      if pretext_validation_ds is not None:
        validation_loss = self._pretext_validation_loss(
            params, state, pretext_validation_ds)
        if validation_loss is None:
          logging.warning('Pretext validation dataset is empty; skipping '
                          'early stopping for epoch %d.', epoch)

      if validation_loss is None:
        # No usable validation signal: keep the latest params/state.
        self.early_stop_params = params
        self.early_stop_state = state
        continue

      if self.writer is not None:
        self.writer.write_scalars(
            epoch,
            {'pretext_validation_loss': validation_loss})
      if validation_loss < self.best_early_stop_loss:
        self.best_early_stop_loss = validation_loss
        self.early_stop_params = params
        self.early_stop_state = state
        self.patience_counter = 0
      else:
        self.patience_counter += 1

      if self.patience_counter > self.patience:
        break

    return self.early_stop_params, self.early_stop_state
241

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.