# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SimCLR.

https://arxiv.org/pdf/2002.05709.pdf
"""

from absl import logging
from flax.core import freeze
import jax
import jax.numpy as jnp
import optax

from q_match.algorithms.training_algo import l2_normalize
from q_match.algorithms.training_algo import PretextTrainingAlgo
from q_match.algorithms.vime_pretext_training import vime_corruption

# Large constant used to mask out self-similarity logits before the softmax.
LARGE_NUM = 1e9


@jax.jit
def modified_loss(z_a,
                  z_b,
                  temperature=.1):
  """Modified contrastive loss.

  We use the diagonal elements of z_a @ z_b.T as the positives, and the
  off-diagonal elements as the negatives.

  A notable difference from the standard SimCLR loss is that the augmented
  pairs of the same example are not included as negatives in this
  implementation.

  Args:
    z_a: Projection from view 1 (n x d).
    z_b: Projection from view 2 (n x d).
    temperature: Temperature for the softmax.

  Returns:
    Loss value.
  """
  z_a = l2_normalize(z_a)
  z_b = l2_normalize(z_b)

  similarity_matrix = jnp.exp(
      z_a @ z_b.T / temperature
  )  # (n x d) (d x n) = n x n

  mask = jnp.eye(similarity_matrix.shape[0], similarity_matrix.shape[1])

  positives = jnp.sum(similarity_matrix * mask, axis=1)  # (n,)
  # Unlike the per-row positives, the negatives are summed over all
  # off-diagonal entries, giving a single scalar shared by every example.
  negatives = jnp.sum(similarity_matrix * (1. - mask))  # scalar

  ratios = positives / negatives

  return jnp.mean(-jnp.log(ratios))
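

def _modified_loss_example():
  """Tiny usage sketch for `modified_loss` (added for illustration).

  This helper is not part of the original q_match module, and the batch
  size, feature width, and PRNG seed are arbitrary.  Passing the same
  projections for both views puts every positive similarity at its maximum,
  exp(1 / temperature), since the cosine similarity of a vector with itself
  is 1 after normalization.
  """
  z = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # 4 examples, 8 dims.
  return modified_loss(z_a=z, z_b=z, temperature=.1)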


@jax.jit
def simclr_loss(z_a, z_b, temperature=.1):
  """Standard SimCLR (NT-Xent) loss.

  Note that, unlike `modified_loss`, this function does not l2-normalize
  its inputs.

  Args:
    z_a: Projection from view 1 (n x d).
    z_b: Projection from view 2 (n x d).
    temperature: Temperature for the softmax.

  Returns:
    Loss value.
  """
  batch_size = z_a.shape[0]
  # The positive for row i sits at column i of the (n x 2n) logits below.
  labels = jax.nn.one_hot(jnp.arange(start=0, stop=batch_size),
                          batch_size * 2)  # (n x 2n)
  masks = jax.nn.one_hot(jnp.arange(start=0, stop=batch_size),
                         batch_size)  # (n x n) identity
  logits_aa = (z_a @ z_a.T) / temperature  # (n x n)
  logits_aa = logits_aa - masks * LARGE_NUM  # Mask out self-similarities.
  logits_bb = (z_b @ z_b.T) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = z_a @ z_b.T / temperature
  logits_ba = z_b @ z_a.T / temperature

  logits_a = jnp.concatenate([logits_ab, logits_aa], axis=1)  # (n x 2n)
  logits_b = jnp.concatenate([logits_ba, logits_bb], axis=1)  # (n x 2n)
  loss_a = optax.softmax_cross_entropy(logits_a, labels)  # (n,)
  loss_b = optax.softmax_cross_entropy(logits_b, labels)  # (n,)

  return jnp.mean(loss_a + loss_b)
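

def _simclr_loss_example():
  """Tiny usage sketch for `simclr_loss` (added for illustration).

  This helper is not part of the original q_match module, and the shapes
  and PRNG seed are arbitrary.  Because `simclr_loss` does not normalize
  its inputs, the random projections are l2-normalized here before the
  loss is computed.
  """
  key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
  z_a = l2_normalize(jax.random.normal(key_a, (4, 16)))
  z_b = l2_normalize(jax.random.normal(key_b, (4, 16)))
  return simclr_loss(z_a, z_b, temperature=.1)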


class SimCLRPretextTraining(PretextTrainingAlgo):
  """SimCLR Training Algorithm.

  Attributes:
    logdir: Location of the log directory.
    dataset: tf dataset to train on.
    batch_size: Batch size for training.
    model: Flax model used for pretext training.
    eval_model: Model used for evaluation.
    learning_rate: Learning rate for training.
    epochs: Number of epochs to train for.
    params: Optional params to start training from.  If None, random params
      are initialized.
    state: Optional state to start training from.
    writer: Writer for writing to tensorboard.
    weight_decay: Weight decay on pretext params.
    corruption_p: Probability of feature corruption for each view.
    patience: Number of epochs without validation improvement to allow
      before early stopping.
    use_modified_loss: Whether to use the modified loss instead of the
      standard SimCLR loss.
    mask_key: PRNG key used for the corruption masks.
  """

  def __init__(
      self,
      logdir,
      dataset,
      batch_size,
      model,
      eval_model,
      learning_rate,
      epochs,
      params=None,
      state=None,
      writer=None,
      weight_decay=0.,
      corruption_p=.3,
      patience=32,
      use_modified_loss=True,
      **kwargs
  ):

    super(SimCLRPretextTraining,
          self).__init__(logdir, dataset, batch_size, model, eval_model,
                         learning_rate, epochs, params, state, writer,
                         weight_decay, patience=patience)

    self.mask_key = jax.random.PRNGKey(99)
    self.corruption_p = corruption_p
    self.use_modified_loss = use_modified_loss

  def _loss(self, params, state,
            features, mask_key,
            ):
    """Computes the SimCLR pretext loss on two corrupted views."""

    variables = freeze({'params': params, **state})

    ## View 1
    view_1_features, _ = vime_corruption(features, self.corruption_p,
                                         mask_key)
    ## View 2
    # The first half of this split becomes the next step's key (see run),
    # so use the second half for view 2.
    _, new_mask_key = jax.random.split(mask_key)
    view_2_features, _ = vime_corruption(
        features, p=self.corruption_p, mask_key=new_mask_key)

    # View 1 Encoded
    output_1, updated_state = self.model.apply(variables,
                                               view_1_features,
                                               mutable=['batch_stats'],
                                               rngs=self.rngs)

    proj_1 = output_1['pretext']['proj']

    # View 2 Encoded
    # Note: only the batch_stats update from this second pass is kept.
    output_2, updated_state = self.model.apply(variables,
                                               view_2_features,
                                               mutable=['batch_stats'],
                                               rngs=self.rngs)
    proj_2 = output_2['pretext']['proj']

    if self.use_modified_loss:
      pretext_loss = modified_loss(z_a=proj_1, z_b=proj_2)
    else:
      pretext_loss = simclr_loss(z_a=proj_1, z_b=proj_2)

    return pretext_loss, updated_state

  def run(self):
    """Runs the pretext training algorithm."""
    params = self.params
    state = self.state
    dataset = self.dataset
    model = self.model

    example_data = jnp.array(dataset.get_example_features())
    variables = freeze({'params': params, **state})
    example_output, _ = model.apply(variables,
                                    example_data,
                                    mutable=['batch_stats'],
                                    rngs=self.rngs,
                                    )
    logging.debug(str(example_output))

    optimizer_state = self.optimizer.init(params=params)

    grad_fn = self.get_grad_fn()

    steps = 0
    for epoch in range(self.epochs):
      logging.info('Pretext Epoch: %d', epoch)
      for example in dataset.get_pretext_ds():
        features = jnp.array(example['features'])

        if steps % 100 == 0:
          pretext_loss, _ = self.loss(
              params, state, features,
              self.mask_key,
              )
          log_train_loss_msg = f'pretext training loss {pretext_loss}'
          logging.info(log_train_loss_msg)

          metrics = {'pretext_train_loss': pretext_loss}

          if self.writer is not None:
            self.writer.write_scalars(steps, metrics)

        gradients, state = grad_fn(params, state,
                                   features, self.mask_key,
                                   )

        params, optimizer_state = self.update_model(params,
                                                    gradients,
                                                    optimizer_state)
        self.update_rngs()
        # Advance the corruption key so each step sees fresh masks.
        self.mask_key, _ = jax.random.split(self.mask_key)
        steps += 1

      # Check the validation pretext dataset if it exists.
      pretext_validation_ds = dataset.get_pretext_validation_ds()
      if pretext_validation_ds is not None:
        # Compute the example-weighted validation loss.
        validation_loss = 0.
        val_seen = 0
        val_mask_key = self.mask_key
        for example in pretext_validation_ds:
          features = jnp.array(example['features'])
          seen = features.shape[0]
          validation_loss += self.loss(
              params,
              state,
              features,
              val_mask_key,
              )[0] * seen
          val_seen += seen
          val_mask_key, _ = jax.random.split(val_mask_key)
        validation_loss /= float(val_seen)

        if self.writer is not None:
          self.writer.write_scalars(
              epoch,
              {'pretext_validation_loss': validation_loss})
        if validation_loss < self.best_early_stop_loss:
          self.best_early_stop_loss = validation_loss
          self.early_stop_params = params
          self.early_stop_state = state
          self.patience_counter = 0
        else:
          self.patience_counter += 1

        if self.patience_counter > self.patience:
          break
      else:
        self.early_stop_params = params
        self.early_stop_state = state

    return self.early_stop_params, self.early_stop_state
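

# Usage sketch (added for illustration; not part of the original module).
# `dataset`, `model`, `eval_model`, and `writer` stand in for the objects
# built by the surrounding q_match experiment code, and the hyperparameter
# values are placeholders:
#
#   trainer = SimCLRPretextTraining(
#       logdir='/tmp/simclr',
#       dataset=dataset,
#       batch_size=256,
#       model=model,
#       eval_model=eval_model,
#       learning_rate=1e-3,
#       epochs=100,
#       writer=writer)
#   pretext_params, pretext_state = trainer.run()

if __name__ == '__main__':
  # Quick smoke test of the two loss variants on random projections.
  print('modified_loss:', _modified_loss_example())
  print('simclr_loss:', _simclr_loss_example())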