# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SimCLR.

https://arxiv.org/pdf/2002.05709.pdf
"""

from absl import logging
from flax.core import freeze
import jax
import jax.numpy as jnp
import optax

from q_match.algorithms.training_algo import l2_normalize
from q_match.algorithms.training_algo import PretextTrainingAlgo
from q_match.algorithms.vime_pretext_training import vime_corruption

LARGE_NUM = 1e9


@jax.jit
def modified_loss(z_a,
                  z_b,
                  temperature=.1):
  """Modified loss.

  We use the diagonal elements of z_a @ z_b.T as the positives, and the
  off-diagonal elements as the negatives.

  A big difference is that the augmented pairs of the same example are not
  included as negatives in this implementation.

  Args:
    z_a: projection from view 1 (n x d).
    z_b: projection from view 2 (n x d).
    temperature: Temperature for softmax.

  Returns:
    Loss value.
  """
  z_a = l2_normalize(z_a)
  z_b = l2_normalize(z_b)

  similarity_matrix = jnp.exp(
      z_a @ z_b.T / temperature
  )  # (n x d) (d x n) = n x n

  mask = jnp.eye(similarity_matrix.shape[0], similarity_matrix.shape[1])

  positives = jnp.sum(similarity_matrix * mask, axis=1)  # (n,)
  # Sum per row so the shape matches the positives; without axis=1 this
  # collapses to a scalar and every example shares one negative term.
  negatives = jnp.sum(similarity_matrix * (1. - mask), axis=1)  # (n,)

  logits = positives / negatives

  return jnp.mean(-jnp.log(logits))
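

# A minimal sanity-check sketch for `modified_loss` (illustrative only; this
# helper is an assumption for demonstration, not part of the q_match
# pipeline). Per row i the loss is
#   -log(exp(za_i . zb_i / t) / sum_{j != i} exp(za_i . zb_j / t)),
# so when the two views agree exactly the positives dominate and the loss
# should come out far lower than for two unrelated random views.
def _example_modified_loss():
  key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
  z = jax.random.normal(key_a, (8, 16))  # A batch of 8 projections, d = 16.
  z_noise = jax.random.normal(key_b, (8, 16))
  aligned_loss = modified_loss(z_a=z, z_b=z)  # Views agree exactly.
  random_loss = modified_loss(z_a=z, z_b=z_noise)  # Views are unrelated.
  return aligned_loss, random_loss  # Expect aligned_loss << random_loss.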


@jax.jit
def simclr_loss(z_a, z_b, temperature=.1):
  """SimCLR loss."""
  batch_size = z_a.shape[0]
  labels = jax.nn.one_hot(jnp.arange(start=0, stop=batch_size),
                          batch_size * 2)  # (n x 2n)
  masks = jax.nn.one_hot(jnp.arange(start=0, stop=batch_size),
                         batch_size)  # (n x n)
  logits_aa = (z_a @ z_a.T) / temperature  # (n x n)
  logits_aa = logits_aa - masks * LARGE_NUM
  logits_bb = (z_b @ z_b.T) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = z_a @ z_b.T / temperature
  logits_ba = z_b @ z_a.T / temperature

  logits_a = jnp.concatenate([logits_ab, logits_aa], axis=1)  # (n x 2n)
  logits_b = jnp.concatenate([logits_ba, logits_bb], axis=1)  # (n x 2n)
  loss_a = optax.softmax_cross_entropy(logits_a, labels)  # (n,)
  loss_b = optax.softmax_cross_entropy(logits_b, labels)  # (n,)

  return jnp.mean(loss_a + loss_b)
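

# A small shape sketch for `simclr_loss` (illustrative only; this helper is
# an assumption for demonstration, not part of the q_match pipeline). Each
# row of `logits_a` scores one positive, the matching row of z_b, against
# 2n - 1 candidates, with the self-similarity in `logits_aa` pushed to -inf
# via LARGE_NUM. Note that, unlike `modified_loss`, the projections are used
# here without l2 normalization.
def _example_simclr_loss():
  z_a = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # n = 4, d = 8.
  # A slightly perturbed copy stands in for the second augmented view.
  z_b = z_a + .01 * jax.random.normal(jax.random.PRNGKey(1), (4, 8))
  return simclr_loss(z_a, z_b)  # Scalar: mean over both view directions.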


class SimCLRPretextTraining(PretextTrainingAlgo):
  """SimCLR Training Algorithm.

  Attributes:
    logdir: Location of the log directory.
    dataset: tf dataset to train.
    batch_size: Batch size for training.
    model: Model to train.
    eval_model: Model used for evaluation.
    learning_rate: The learning rate for training.
    epochs: Number of epochs to train for.
    params: Optional params to start training from. If None, random params
      are initialized.
    state: Optional state to start training from.
    writer: Writer for writing to tensorboard.
    weight_decay: Weight decay on pretext params.
    corruption_p: The probability of corrupting a feature for view 1. View 2
      uses the same probability with a fresh mask key.
    patience: Number of validation epochs without improvement to allow
      before stopping early.
    use_modified_loss: Whether to use the modified loss instead of the
      standard SimCLR loss.
  """

  def __init__(
      self,
      logdir,
      dataset,
      batch_size,
      model,
      eval_model,
      learning_rate,
      epochs,
      params=None,
      state=None,
      writer=None,
      weight_decay=0.,
      corruption_p=.3,
      patience=32,
      use_modified_loss=True,
      **kwargs
  ):

    super(SimCLRPretextTraining,
          self).__init__(logdir, dataset, batch_size, model, eval_model,
                         learning_rate, epochs, params, state, writer,
                         weight_decay, patience=patience)

    self.mask_key = jax.random.PRNGKey(99)
    self.corruption_p = corruption_p
    self.use_modified_loss = use_modified_loss

  def _loss(self, params, state,
            features, mask_key,
            ):
    """Computes the SimCLR pretext loss over two corrupted views."""

    variables = freeze({'params': params, **state})

    ## View 1
    view_1_features, _ = vime_corruption(features, self.corruption_p,
                                         mask_key)
    ## View 2
    # Use the first key later, so pick second.
    _, new_mask_key = jax.random.split(self.mask_key)
    view_2_features, _ = vime_corruption(
        features, p=self.corruption_p, mask_key=new_mask_key)

    # View 1 Encoded
    output_1, updated_state = self.model.apply(variables,
                                               view_1_features,
                                               mutable=['batch_stats'],
                                               rngs=self.rngs)

    proj_1 = output_1['pretext']['proj']

    # View 2 Encoded
    # Note: this overwrites the view-1 state, so only the batch stats from
    # the second pass are returned.
    output_2, updated_state = self.model.apply(variables,
                                               view_2_features,
                                               mutable=['batch_stats'],
                                               rngs=self.rngs)
    proj_2 = output_2['pretext']['proj']

    if self.use_modified_loss:
      pretext_loss = modified_loss(z_a=proj_1, z_b=proj_2)
    else:
      pretext_loss = simclr_loss(z_a=proj_1, z_b=proj_2)

    return pretext_loss, updated_state
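
  # How `_loss` is consumed downstream (a hedged sketch: the real gradient
  # function comes from PretextTrainingAlgo.get_grad_fn, assumed here to
  # behave roughly like jax.grad with has_aux):
  #
  #   grad_fn = jax.grad(self._loss, has_aux=True)
  #   gradients, updated_state = grad_fn(params, state, features, mask_key)
  #
  # The (pretext_loss, updated_state) return above matches that convention:
  # gradients are taken w.r.t. params, and the mutated batch stats ride
  # along as the auxiliary output.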

  def run(self):
    """Runs a pretext training algo."""
    params = self.params
    state = self.state
    dataset = self.dataset
    model = self.model

    example_data = jnp.array(dataset.get_example_features())
    variables = freeze({'params': params, **state})
    example_output, _ = model.apply(variables,
                                    example_data,
                                    mutable=['batch_stats'],
                                    rngs=self.rngs,
                                    )
    logging.debug(str(example_output))

    optimizer_state = self.optimizer.init(params=params)

    grad_fn = self.get_grad_fn()

    steps = 0
    for epoch in range(self.epochs):
      logging.info('Pretext Epoch: %d', epoch)
      for example in dataset.get_pretext_ds():
        features = jnp.array(example['features'])

        if steps % 100 == 0:
          pretext_loss, _ = self.loss(
              params, state, features,
              self.mask_key,
          )
          log_train_loss_msg = f'pretext training loss {pretext_loss}'
          logging.info(log_train_loss_msg)

          metrics = {'pretext_train_loss': pretext_loss}

          if self.writer is not None:
            self.writer.write_scalars(steps, metrics)

        gradients, state = grad_fn(params, state,
                                   features, self.mask_key,
                                   )

        params, optimizer_state = self.update_model(params,
                                                    gradients,
                                                    optimizer_state)
        self.update_rngs()
        # Advance the mask key so each step corrupts with fresh randomness.
        self.mask_key, _ = jax.random.split(self.mask_key)
        steps += 1

      # Check the pretext validation dataset if it exists.
      pretext_validation_ds = dataset.get_pretext_validation_ds()
      if pretext_validation_ds is not None:
        # Compute the validation loss as an example-weighted average.
        validation_loss = 0.
        val_seen = 0
        val_mask_key = self.mask_key
        for example in pretext_validation_ds:
          features = jnp.array(example['features'])
          seen = features.shape[0]
          validation_loss += self.loss(
              params,
              state,
              features,
              val_mask_key,
          )[0] * seen
          val_seen += seen
          val_mask_key, _ = jax.random.split(val_mask_key)
        validation_loss /= float(val_seen)

        if self.writer is not None:
          self.writer.write_scalars(
              epoch,
              {'pretext_validation_loss': validation_loss})
        if validation_loss < self.best_early_stop_loss:
          self.best_early_stop_loss = validation_loss
          self.early_stop_params = params
          self.early_stop_state = state
          self.patience_counter = 0
        else:
          self.patience_counter += 1

        if self.patience_counter > self.patience:
          break
      else:
        self.early_stop_params = params
        self.early_stop_state = state

    return self.early_stop_params, self.early_stop_state
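

# A hedged usage sketch (the names and values below are illustrative
# assumptions, not a documented entry point). The `dataset` object must
# provide get_example_features, get_pretext_ds, and
# get_pretext_validation_ds as used in `run` above, and the Flax `model`
# must expose output['pretext']['proj']:
#
#   algo = SimCLRPretextTraining(
#       logdir='/tmp/simclr',
#       dataset=dataset,
#       batch_size=256,
#       model=model,
#       eval_model=eval_model,
#       learning_rate=1e-3,
#       epochs=100,
#       use_modified_loss=True,
#   )
#   pretext_params, pretext_state = algo.run()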