# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for image classification."""

from absl import app
from absl import flags
from absl import logging
from clu import platform
from flax.deprecated import nn
import jax
import jax.numpy as jnp
from lib import data
from lib import models
from lib import utils
import lib.classification_utils as classification_lib
from lib.layers import sample_patches
import ml_collections
import ml_collections.config_flags as config_flags
import optax
import tensorflow as tf

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work unit directory.")
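
# Example invocation (hypothetical paths; the actual experiment config is
# whatever ConfigDict file is passed via --config):
#   python image_classification.py \
#     --config=configs/my_experiment.py \
#     --workdir=/tmp/image_classification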


class ClassificationModule(nn.Module):
  """A module that does classification."""

  def apply(self, x, config, num_classes, train=True):
    """Creates a model definition."""

    if config.get("append_position_to_input", False):
      b, h, w, _ = x.shape
      coords = utils.create_grid([h, w], value_range=(0., 1.))
      x = jnp.concatenate([x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)],
                          axis=-1)

    if config.model.lower() == "cnn":
      h = models.SimpleCNNImageClassifier(x)
      h = nn.relu(h)
      stats = None
    elif config.model.lower() == "resnet":
      smallinputs = config.get("resnet.small_inputs", False)
      blocks = config.get("resnet.blocks", [3, 4, 6, 3])
      h = models.ResNet(
          x, train=train, block_sizes=blocks, small_inputs=smallinputs)
      h = jnp.mean(h, axis=[1, 2])   # global average pool
      stats = None
    elif config.model.lower() == "resnet18":
      h = models.ResNet18(x, train=train)
      h = jnp.mean(h, axis=[1, 2])   # global average pool
      stats = None
    elif config.model.lower() == "resnet50":
      h = models.ResNet50(x, train=train)
      h = jnp.mean(h, axis=[1, 2])   # global average pool
      stats = None
    elif config.model.lower() == "ats-traffic":
      h = models.ATSFeatureNetwork(x, train=train)
      stats = None
    elif config.model.lower() == "patchnet":
      feature_network = {
          "resnet18": models.ResNet18,
          "resnet18-fourth": models.ResNet.partial(
              num_filters=16,
              block_sizes=(2, 2, 2, 2),
              block=models.BasicBlock),
          "resnet50": models.ResNet50,
          "ats-traffic": models.ATSFeatureNetwork,
      }[config.feature_network.lower()]

      selection_method = sample_patches.SelectionMethod(config.selection_method)
      selection_method_kwargs = {}
      if selection_method is sample_patches.SelectionMethod.SINKHORN_TOPK:
        selection_method_kwargs = config.sinkhorn_topk_kwargs
      if selection_method is sample_patches.SelectionMethod.PERTURBED_TOPK:
        selection_method_kwargs = config.perturbed_topk_kwargs

      h, stats = sample_patches.PatchNet(
          x,
          patch_size=config.patch_size,
          k=config.k,
          downscale=config.downscale,
          scorer_has_se=config.get("scorer_has_se", False),
          selection_method=config.selection_method,
          selection_method_kwargs=selection_method_kwargs,
          selection_method_inference=config.get("selection_method_inference",
                                                None),
          normalization_str=config.normalization_str,
          aggregation_method=config.aggregation_method,
          aggregation_method_kwargs=config.get("aggregation_method_kwargs", {}),
          append_position_to_input=config.get("append_position_to_input",
                                              False),
          feature_network=feature_network,
          use_iterative_extraction=config.use_iterative_extraction,
          hard_topk_probability=config.get("hard_topk_probability", 0.),
          random_patch_probability=config.get("random_patch_probability", 0.),
          train=train)
      stats["x"] = x
    else:
      raise RuntimeError(
          "Unknown classification model type: %s" % config.model.lower())
    out = nn.Dense(h, num_classes, name="final")
    return out, stats
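

# Illustrative sketch only: the model-related config fields consumed by the
# "patchnet" branch of ClassificationModule above, with placeholder values.
# Real experiments use the ConfigDict passed via --config; valid values for
# `selection_method`, `normalization_str` and `aggregation_method` are defined
# in lib/layers/sample_patches.py. This helper is not called anywhere in this
# file.
def example_patchnet_model_config():
  """Returns a hypothetical model section of a patchnet config."""
  config = ml_collections.ConfigDict()
  config.model = "patchnet"
  config.feature_network = "resnet18"
  # Placeholder; must name a valid sample_patches.SelectionMethod.
  config.selection_method = "perturbed-topk"
  config.perturbed_topk_kwargs = ml_collections.ConfigDict(
      dict(num_samples=100, sigma=0.05))  # Placeholder values.
  config.patch_size = 32
  config.k = 10
  config.downscale = 4
  config.normalization_str = "identity"  # Placeholder.
  config.aggregation_method = "meanpooling"  # Placeholder.
  config.use_iterative_extraction = True
  return config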


def create_optimizer(config):
  """Creates the optimizer associated to a config."""
  ops = []

  # Gradient clipping either by norm `gradient_norm_clip` or by absolute value
  # `gradient_value_clip`.
  if "gradient_clip" in config:
    raise ValueError("'gradient_clip' is deprecated, please use "
                     "'gradient_norm_clip'.")
  assert not ("gradient_norm_clip" in config and
              "gradient_value_clip" in config), (
                  "Gradient clipping by norm and by value are exclusive.")

  if "gradient_norm_clip" in config:
    ops.append(optax.clip_by_global_norm(config.gradient_norm_clip))
  if "gradient_value_clip" in config:
    ops.append(optax.clip(config.gradient_value_clip))

  # Define the learning rate schedule.
  schedule_fn = utils.get_optax_schedule_fn(
      warmup_ratio=config.get("warmup_ratio", 0.),
      num_train_steps=config.num_train_steps,
      decay=config.get("learning_rate_step_decay", 1.0),
      decay_at_steps=config.get("learning_rate_decay_at_steps", []),
      cosine_decay_schedule=config.get("cosine_decay", False))

  schedule_ops = [optax.scale_by_schedule(schedule_fn)]

  # Scale some parameters matching a regex by a multiplier. Config field
  # `scaling_learning_rate_by_regex` is a list of pairs
  # (regex: str, multiplier: float).
  scaling_by_regex = config.get("scaling_learning_rate_by_regex", [])
  for regex, multiplier in scaling_by_regex:
    logging.info("Learning rate is scaled by %f for parameters matching '%s'",
                 multiplier, regex)
    schedule_ops.append(utils.scale_selected_parameters(regex, multiplier))
  schedule_optimizer = optax.chain(*schedule_ops)

  if config.optimizer.lower() == "adam":
    optimizer = optax.adam(config.learning_rate)
    ops.append(optimizer)
    ops.append(schedule_optimizer)
  elif config.optimizer.lower() == "sgd":
    ops.append(schedule_optimizer)
    optimizer = optax.sgd(config.learning_rate, momentum=config.momentum)
    ops.append(optimizer)
  else:
    raise NotImplementedError("Invalid optimizer: {}".format(
        config.optimizer))

  if "weight_decay" in config and config.weight_decay > 0.:
    ops.append(utils.decoupled_weight_decay(
        decay=config.weight_decay, step_size_fn=schedule_fn))

  # Freeze parameters that match the given regexes (if any).
  freeze_weights_regexes = config.get("freeze_weights_regex", []) or []
  if isinstance(freeze_weights_regexes, str):
    freeze_weights_regexes = [freeze_weights_regexes]
  for reg in freeze_weights_regexes:
    ops.append(utils.freeze(reg))

  return optax.chain(*ops)
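

# Illustrative sketch only: the optimizer-related config fields read by
# `create_optimizer` above, with placeholder values. `gradient_norm_clip` and
# `gradient_value_clip` are mutually exclusive, and `num_train_steps` must be
# set for the schedule. With these values, `create_optimizer` chains gradient
# clipping, the learning-rate schedule, SGD with momentum, and decoupled
# weight decay. This helper is not called anywhere in this file.
def example_optimizer_config():
  """Returns a hypothetical optimizer section of a config."""
  config = ml_collections.ConfigDict()
  config.optimizer = "sgd"
  config.learning_rate = 0.1
  config.momentum = 0.9
  config.num_train_steps = 10000
  config.warmup_ratio = 0.05
  config.cosine_decay = True
  config.gradient_norm_clip = 1.0
  config.weight_decay = 1e-4
  return config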


def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Returns:
    Training state.
  """
  rng = jax.random.PRNGKey(config.seed)
  rng, data_rng = jax.random.split(rng)

  # Make sure config defines num_epochs and num_train_steps appropriately.
  utils.check_epochs_and_steps(config)

  train_preprocessing_fn, eval_preprocessing_fn = data.parse_preprocessing_strings(
      config.get("train_preprocess_str", ""),
      config.get("eval_preprocess_str", ""))

  assert config.batch_size % jax.local_device_count() == 0, (
      f"Batch size ({config.batch_size}) should be divisible by number of "
      f"devices ({jax.local_device_count()}).")

  per_device_batch_size = config.batch_size // jax.local_device_count()
  train_ds, eval_ds, num_classes = data.get_dataset(
      config.dataset,
      per_device_batch_size,
      data_rng,
      train_preprocessing_fn=train_preprocessing_fn,
      eval_preprocessing_fn=eval_preprocessing_fn,
      **config.get("data", {}))

  module = ClassificationModule.partial(config=config, num_classes=num_classes)

  optimizer = create_optimizer(config)

  # Enable the relevant statistics aggregators.
  stats_aggregators = []

  train_metrics_dict = {
      "train_loss": classification_lib.cross_entropy,
      "train_accuracy": classification_lib.accuracy
  }
  eval_metrics_dict = {
      "eval_loss": classification_lib.cross_entropy,
      "eval_accuracy": classification_lib.accuracy
  }
  loss_fn = classification_lib.cross_entropy

  def loss_from_stats(field, multiplier):
    return lambda logits, labels, stats: multiplier * stats[field]

  # Add an entropy regularizer to the loss if needed.
  if (config.model == "patchnet" and
      config.selection_method not in [sample_patches.SelectionMethod.HARD_TOPK,
                                      sample_patches.SelectionMethod.RANDOM]):
    entropy_regularizer = config.get("entropy_regularizer", 0.)
    entropy_before_normalization = config.get("entropy_before_normalization",
                                              False)

    stat_field = "entropy"
    if entropy_before_normalization:
      stat_field = "entropy_before_normalization"

    if entropy_regularizer != 0.:
      logging.info("Adding entropy regularizer (%s normalization) to the loss "
                   "with weight %f.",
                   "before" if entropy_before_normalization else "after",
                   entropy_regularizer)
      loss_fn = [loss_fn, loss_from_stats(stat_field, entropy_regularizer)]

    def entropy_aggregator(stats):
      return {stat_field: stats[stat_field]}
    stats_aggregators.append(entropy_aggregator)

  def add_image_prefix(image_aggregator):
    def aggregator(stats):
      d = image_aggregator(stats)
      return {f"image_{k}": v for k, v in d.items()}
    return aggregator

  if config.model == "patchnet" and config.get("log_images", True):
    @add_image_prefix
    def plot_patches(stats):
      keys = ["extracted_patches", "x", "scores"]
      return {k: stats[k] for k in keys if k in stats}

    stats_aggregators.append(plot_patches)

  state = classification_lib.training_loop(
      module=module,
      rng=rng,
      train_ds=train_ds,
      eval_ds=eval_ds,
      loss_fn=loss_fn,
      optimizer=optimizer,
      train_metrics_dict=train_metrics_dict,
      eval_metrics_dict=eval_metrics_dict,
      stats_aggregators=stats_aggregators,
      config=config,
      workdir=workdir)
  return state
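

# Illustrative sketch only: a custom statistics aggregator follows the same
# convention as `entropy_aggregator` and `plot_patches` above -- it takes the
# `stats` dict returned by the model and returns a dict of values to log. To
# use it, append it to `stats_aggregators` inside `train_and_evaluate`; it is
# not wired up anywhere in this file.
def example_score_aggregator(stats):
  """Logs the mean patch score when the model reports one (hypothetical)."""
  if stats and "scores" in stats:
    return {"mean_patch_score": jnp.mean(stats["scores"])}
  return {}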


def main(argv):
  del argv

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
  logging.info("JAX devices: %r", jax.devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Borg task 0 is not guaranteed to be host 0)
  platform.work_unit().set_task_status(
      f"host_id: {jax.host_id()}, host_count: {jax.host_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  state = train_and_evaluate(FLAGS.config, FLAGS.workdir)
  del state


if __name__ == "__main__":
  flags.mark_flags_as_required(["config", "workdir"])
  app.run(main)