# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main train and eval loop for SimCLR+linear layer experiments.

Given an existing trained SimCLR model, trains a linear layer on top to predict
the original latents from the dsprites dataset.
"""

from absl import app
from absl import flags
from absl import logging

from simclr.tf2 import lars_optimizer as lars
import tensorflow.compat.v2 as tf

import graph_compression.contrastive_learning.data_utils.learning_latents as data_lib  # pylint: disable=unused-import
import graph_compression.contrastive_learning.datasets.learning_latents as datasets_lib
import graph_compression.contrastive_learning.metrics_utils.learning_latents as metrics_lib
import graph_compression.contrastive_learning.models.learning_latents as models_lib


FLAGS = flags.FLAGS

USE_TPU = flags.DEFINE_boolean('use_tpu', False, 'For TPU training.')

TPU_ADDRESS = flags.DEFINE_string('tpu_address', None,
                                  'Manually specify a TPU address.')

MASTER = flags.DEFINE_string('master', '',
                             'Required for compatibility, leave blank.')

LR = flags.DEFINE_float('learning_rate', 1e-1,
                        'Learning rate for linear layer.')

L2 = flags.DEFINE_float('l2_penalty', 1e-4, 'Penalty for L2 regularization.')

T_BATCHSIZE = flags.DEFINE_integer('train_batch_size', 512,
                                   'Batch size for training.')

T_STEPS_PER_LOOP = flags.DEFINE_integer(
    'train_steps_per_loop', 5,
    'How many train steps to run between metrics summaries updates.')

TOTAL_STEPS = flags.DEFINE_integer('total_steps', 1,
                                   'Number of steps to train for.')

DATA_DIR = flags.DEFINE_string(
    'data_dir', None,
    'Directory to write metrics summaries to. If unset, no metrics are logged.')

IMG_SIZE = flags.DEFINE_list(
    'img_size', None,
    'Optional image rescaling (comma separated list representing [new_height, new_width]).'
)

NUM_CHANNELS = flags.DEFINE_integer(
    'num_channels', None, 'Optional image tiling to multiple channels.')

PRETRAINED_MODEL_PATH = flags.DEFINE_string('pretrained_model_path', None,
                                            'Path to saved pretrained model.')

EVAL_SPLIT = flags.DEFINE_float('eval_split', 0.1,
                                'Fraction of dataset to use for eval.')

E_BATCHSIZE = flags.DEFINE_integer('eval_batch_size', 2048,
                                   'Batch size for eval.')

E_FREQ = flags.DEFINE_integer(
    'eval_frequency', 5, 'How often (in train steps) to run the eval loop.')

SEED = flags.DEFINE_integer('seed', None, 'Specify a random seed.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # set up tpu strategy
  if USE_TPU.value:
    tpu_address = TPU_ADDRESS.value or MASTER.value

    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_address)
    tf.config.experimental_connect_to_cluster(cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)

  else:
    # no-op strategy: for debugging / running on one machine
    strategy = tf.distribute.get_strategy()
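    # With this default strategy, `strategy.run` and
    # `experimental_distribute_dataset` below are effectively pass-throughs,
    # so the same loop runs unchanged on a single CPU/GPU machine.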

  # actual run

  with strategy.scope():

    # optional image rescaling; default to the dataset's native size
    img_size = None
    if IMG_SIZE.value is not None:
      img_size = tf.convert_to_tensor([int(v) for v in IMG_SIZE.value])

    # set up dataset
    dataset = datasets_lib.get_standard_dataset(
        name='dsprites',
        img_size=img_size,
        num_channels=NUM_CHANNELS.value,
        eval_split=EVAL_SPLIT.value,
        seed=SEED.value)
    train_df, train_ds, num_train_examples = dataset['train']
    eval_df, eval_ds, num_eval_examples = dataset['eval']

    del train_df, eval_df  # not used here

    logging.info('Train, eval sets contain %s, %s elements respectively',
                 num_train_examples, num_eval_examples)
    # each dsprites example carries a vector of latent values; its length gives
    # the number of regression targets for the linear layer
    for x in train_ds.take(1):
      num_classes = x['values'].shape[0]
    logging.info('Num classes is %s', num_classes)

    logging.info('Setting up datasets...')
    t_batchsize, e_batchsize = T_BATCHSIZE.value, E_BATCHSIZE.value
    train_ds = train_ds.shuffle(
        buffer_size=t_batchsize * 10, reshuffle_each_iteration=True)
    # drop the final partial batch for tpu reasons
    train_ds_batched = train_ds.batch(t_batchsize, drop_remainder=True)
    eval_ds_batched = eval_ds.batch(e_batchsize, drop_remainder=True)
    # update the example counts to account for the dropped partial batches
    num_train_examples = (num_train_examples // t_batchsize) * t_batchsize
    num_eval_examples = (num_eval_examples // e_batchsize) * e_batchsize

    train_ds_dist = strategy.experimental_distribute_dataset(train_ds_batched)
    eval_ds_dist = strategy.experimental_distribute_dataset(eval_ds_batched)
    logging.info('Datasets set up, setting up model...')

    # instantiate optimizer and regularizer
    optimizer = lars.LARSOptimizer(LR.value)

    # loss has to be handled carefully when distributed on multiple cores;
    # specify no reduction for now and handle reduction manually in train step.
    loss_fn = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE)
    regularizer = tf.keras.regularizers.L2(L2.value)

    # set up model

    model = models_lib.LinearLayerOverPretrainedSimclrModel(
        PRETRAINED_MODEL_PATH.value, optimizer, num_classes)
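    # Only the linear head is trained here: the gradient step below computes
    # and applies gradients for model.dense_layer only, so the pretrained
    # SimCLR encoder stays frozen.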

    logging.info('Optimizer, loss fn, model set up')

  # set up metrics and summary writers for train and eval, if required
  if DATA_DIR.value is not None:
    # metrics need to be within a strategy scope
    with strategy.scope():
      logging.info('Starting on computing y_bar and tss...')
      # TODO(zeef): implement a load from cache option to replace this
      # For testing, just hardcode the values because the computation is slow.
      # These are the values for the dsprites dataset.

      # y_bar, tss = metrics_lib.get_tss_for_r2(strategy, eval_ds_dist,
      #                                         num_classes, num_eval_examples,
      #                                         e_batchsize)

      y_bar = tf.constant([
          0.33251953, 0.33551705, 0.33196342, 0.74878615, 0.50025487,
          0.49955714, 0.5002258
      ],
                          dtype=tf.float32)
      tss = tf.constant([
          16363.938, 16437.312, 16350.195, 2150.051, 6469.37, 6572.2305,
          6550.672
      ],
                        dtype=tf.float32)
      logging.info('CAUTION! Hardcoded values for y_bar and tss!')

      logging.info('y_bar, tss are %s, %s', y_bar, tss)

      train_metrics = metrics_lib.DspritesTrainMetrics(DATA_DIR.value)
      eval_metrics = metrics_lib.DspritesEvalMetrics(DATA_DIR.value, tss)

    logging.info('Metrics set up')

  # define functions for train step, eval step, metrics update step
  @tf.function
  def train_step_loop(iterator, steps_per_loop):

    def step_fn(x):
      with tf.GradientTape() as tape:
        preds = model(x['image'])
        # loss is a tensor of per-example losses of size batch_size/num_replicas
        loss = loss_fn(x['values'], preds)
        loss += tf.reduce_sum(
            [regularizer(w) for w in model.dense_layer.trainable_weights])
        # pass this to metrics first so it gets an accurate count of examples
        if DATA_DIR.value is not None:
          train_metrics.update_metrics(loss, x['values'], preds)
        # now average the loss and then also divide by number of replicas
        # since the gradients from each replica are added together
        loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync
      dense_layer_weights = model.dense_layer.trainable_weights
      grads = tape.gradient(loss, dense_layer_weights)
      model.optimizer.apply_gradients(zip(grads, dense_layer_weights))

    # run several steps inside one tf.function call to cut per-step overhead
    for _ in tf.range(steps_per_loop):
      strategy.run(step_fn, args=(next(iterator),))

  @tf.function
  def eval_step_loop(iterator, steps_per_loop):

    def step_fn(x):
      preds = model(x['image'])
      loss = loss_fn(x['values'], preds)
      # update eval metrics
      if DATA_DIR.value is not None:
        eval_metrics.update_metrics(loss, x['values'], preds)
      # no need to worry about scaling the loss here, since no gradients
      # are computed during eval

    for _ in tf.range(steps_per_loop):
      strategy.run(step_fn, args=(next(iterator),))

  def metrics_update_loop(metrics_obj, global_step):
    # write each metric to its summary writer, then reset it for the next loop
    for k in metrics_obj.writer_names:
      logging.info('Writing metric: %s', k)
      with metrics_obj.summary_writers[k].as_default():
        metrics_obj.write_metrics_to_summary(
            metrics_obj.metrics_dict[k], global_step=global_step)
        metrics_obj.summary_writers[k].flush()
      for metric in metrics_obj.metrics_dict[k]:
        metric.reset_state()

  # training loop

  num_eval_steps = num_eval_examples // e_batchsize
  num_train_steps = num_train_examples // t_batchsize
  num_train_steps_per_loop = T_STEPS_PER_LOOP.value
  num_train_loops_per_eval = E_FREQ.value // num_train_steps_per_loop
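  # eval_frequency is measured in train steps and is integer-divided by
  # train_steps_per_loop, so with the defaults (5 and 5) every loop of 5 train
  # steps is followed by a full eval pass; for other settings, eval_frequency
  # should be a multiple of train_steps_per_loop.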

  train_iterator_step = 0
  current_step = 0

  logging.info('Starting main training loop')
  train_iterator = iter(train_ds_dist)
  while current_step < TOTAL_STEPS.value:
    logging.info('current step %s, global step %s', current_step,
                 optimizer.iterations.numpy())
    for _ in range(num_train_loops_per_eval):
      # check there are enough examples left in the iterator; remake if needed
      # TODO(zeef): rewrite dataset creation to repeat forever?
      if train_iterator_step + num_train_steps_per_loop >= num_train_steps:
        train_iterator = iter(train_ds_dist)
        train_iterator_step = 0

      train_step_loop(train_iterator, num_train_steps_per_loop)

      if DATA_DIR.value is not None:
        metrics_update_loop(train_metrics, optimizer.iterations.numpy())

      # keep track of how far through train_iterator we are
      train_iterator_step += num_train_steps_per_loop

    # now run through the entire eval dataset and update eval metrics
    logging.info('Updating eval metrics for step %s',
                 optimizer.iterations.numpy())
    eval_iterator = iter(eval_ds_dist)
    eval_step_loop(eval_iterator, num_eval_steps)
    if DATA_DIR.value is not None:
      metrics_update_loop(eval_metrics, optimizer.iterations.numpy())

    # finally update current_step for the while loop to check
    current_step = optimizer.iterations.numpy()
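    # optimizer.iterations counts applied gradient updates, so total_steps is
    # measured in optimizer steps rather than while-loop iterations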


if __name__ == '__main__':
  tf.compat.v1.enable_v2_behavior()
  app.run(main)