# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for image classification."""

from absl import app
from absl import flags
from absl import logging
from clu import platform
from flax.deprecated import nn
import jax
import jax.numpy as jnp
from lib import data
from lib import models
from lib import utils
import lib.classification_utils as classification_lib
from lib.layers import sample_patches
import ml_collections
import ml_collections.config_flags as config_flags
import optax
import tensorflow as tf

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work unit directory.")
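
# Example invocation (the config path is illustrative):
#   python main.py --config=configs/cifar10.py --workdir=/tmp/workdir
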

class ClassificationModule(nn.Module):
  """A module that does classification."""

  def apply(self, x, config,
            num_classes, train=True):
    """Creates a model definition."""

    if config.get("append_position_to_input", False):
      b, h, w, _ = x.shape
      coords = utils.create_grid([h, w], value_range=(0., 1.))
      x = jnp.concatenate([x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)],
                          axis=-1)

    if config.model.lower() == "cnn":
      h = models.SimpleCNNImageClassifier(x)
      h = nn.relu(h)
      stats = None
    elif config.model.lower() == "resnet":
      smallinputs = config.get("resnet.small_inputs", False)
      blocks = config.get("resnet.blocks", [3, 4, 6, 3])
      h = models.ResNet(
          x, train=train, block_sizes=blocks, small_inputs=smallinputs)
      h = jnp.mean(h, axis=[1, 2])  # global average pool
      stats = None
    elif config.model.lower() == "resnet18":
      h = models.ResNet18(x, train=train)
      h = jnp.mean(h, axis=[1, 2])  # global average pool
      stats = None
    elif config.model.lower() == "resnet50":
      h = models.ResNet50(x, train=train)
      h = jnp.mean(h, axis=[1, 2])  # global average pool
      stats = None
    elif config.model.lower() == "ats-traffic":
      h = models.ATSFeatureNetwork(x, train=train)
      stats = None
    elif config.model.lower() == "patchnet":
      feature_network = {
          "resnet18": models.ResNet18,
          "resnet18-fourth": models.ResNet.partial(
              num_filters=16,
              block_sizes=(2, 2, 2, 2),
              block=models.BasicBlock),
          "resnet50": models.ResNet50,
          "ats-traffic": models.ATSFeatureNetwork,
      }[config.feature_network.lower()]

      selection_method = sample_patches.SelectionMethod(config.selection_method)
      selection_method_kwargs = {}
      if selection_method is sample_patches.SelectionMethod.SINKHORN_TOPK:
        selection_method_kwargs = config.sinkhorn_topk_kwargs
      if selection_method is sample_patches.SelectionMethod.PERTURBED_TOPK:
        selection_method_kwargs = config.perturbed_topk_kwargs

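      # Both Sinkhorn top-k and perturbed top-k are differentiable relaxations
      # of hard top-k patch selection; their kwargs (e.g. temperature or noise
      # parameters, depending on the method) come straight from the config.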
      h, stats = sample_patches.PatchNet(
          x,
          patch_size=config.patch_size,
          k=config.k,
          downscale=config.downscale,
          scorer_has_se=config.get("scorer_has_se", False),
          selection_method=config.selection_method,
          selection_method_kwargs=selection_method_kwargs,
          selection_method_inference=config.get("selection_method_inference",
                                                None),
          normalization_str=config.normalization_str,
          aggregation_method=config.aggregation_method,
          aggregation_method_kwargs=config.get("aggregation_method_kwargs", {}),
          append_position_to_input=config.get("append_position_to_input",
                                              False),
          feature_network=feature_network,
          use_iterative_extraction=config.use_iterative_extraction,
          hard_topk_probability=config.get("hard_topk_probability", 0.),
          random_patch_probability=config.get("random_patch_probability", 0.),
          train=train)
      stats["x"] = x
    else:
      raise RuntimeError(
          "Unknown classification model type: %s" % config.model.lower())
    out = nn.Dense(h, num_classes, name="final")
    return out, stats

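
# A minimal usage sketch, assuming the deprecated flax.nn module API (the
# input shape is illustrative; batch-norm state, if any, is collected by
# nn.stateful):
#   module = ClassificationModule.partial(config=config, num_classes=10)
#   with nn.stateful() as init_state:
#     _, params = module.init_by_shape(
#         jax.random.PRNGKey(0), [((1, 32, 32, 3), jnp.float32)])
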
def create_optimizer(config):
  """Creates the optimizer associated to a config."""
  ops = []

  # Gradient clipping, either by norm (`gradient_norm_clip`) or by absolute
  # value (`gradient_value_clip`).
  if "gradient_clip" in config:
    raise ValueError("'gradient_clip' is deprecated, please use "
                     "'gradient_norm_clip'.")
  assert not ("gradient_norm_clip" in config and
              "gradient_value_clip" in config), (
                  "Gradient clipping by norm and by value are exclusive.")

  if "gradient_norm_clip" in config:
    ops.append(optax.clip_by_global_norm(config.gradient_norm_clip))
  if "gradient_value_clip" in config:
    ops.append(optax.clip(config.gradient_value_clip))

  # Define the learning rate schedule.
  schedule_fn = utils.get_optax_schedule_fn(
      warmup_ratio=config.get("warmup_ratio", 0.),
      num_train_steps=config.num_train_steps,
      decay=config.get("learning_rate_step_decay", 1.0),
      decay_at_steps=config.get("learning_rate_decay_at_steps", []),
      cosine_decay_schedule=config.get("cosine_decay", False))

  schedule_ops = [optax.scale_by_schedule(schedule_fn)]

  # Scale some parameters matching a regex by a multiplier. The config field
  # `scaling_learning_rate_by_regex` is a list of (regex: str,
  # multiplier: float) pairs.
  scaling_by_regex = config.get("scaling_learning_rate_by_regex", [])
  for regex, multiplier in scaling_by_regex:
    logging.info("Learning rate is scaled by %f for parameters matching '%s'",
                 multiplier, regex)
    schedule_ops.append(utils.scale_selected_parameters(regex, multiplier))
  schedule_optimizer = optax.chain(*schedule_ops)

  if config.optimizer.lower() == "adam":
    optimizer = optax.adam(config.learning_rate)
    ops.append(optimizer)
    ops.append(schedule_optimizer)
  elif config.optimizer.lower() == "sgd":
    ops.append(schedule_optimizer)
    optimizer = optax.sgd(config.learning_rate, momentum=config.momentum)
    ops.append(optimizer)
  else:
    raise NotImplementedError("Invalid optimizer: {}".format(config.optimizer))

  if "weight_decay" in config and config.weight_decay > 0.:
    ops.append(utils.decoupled_weight_decay(
        decay=config.weight_decay, step_size_fn=schedule_fn))

  # Freeze parameters that match the given regexes (if any).
  freeze_weights_regexes = config.get("freeze_weights_regex", []) or []
  if isinstance(freeze_weights_regexes, str):
    freeze_weights_regexes = [freeze_weights_regexes]
  for reg in freeze_weights_regexes:
    ops.append(utils.freeze(reg))

  return optax.chain(*ops)
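
# For a typical Adam config with gradient-norm clipping and no weight decay,
# the chain built above is roughly equivalent to this sketch:
#   optax.chain(
#       optax.clip_by_global_norm(config.gradient_norm_clip),
#       optax.adam(config.learning_rate),
#       optax.scale_by_schedule(schedule_fn))
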
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Returns:
    Training state.
  """
  rng = jax.random.PRNGKey(config.seed)
  rng, data_rng = jax.random.split(rng)

  # Make sure the config defines num_epochs and num_train_steps appropriately.
  utils.check_epochs_and_steps(config)

  train_preprocessing_fn, eval_preprocessing_fn = (
      data.parse_preprocessing_strings(
          config.get("train_preprocess_str", ""),
          config.get("eval_preprocess_str", "")))

  assert config.batch_size % jax.local_device_count() == 0, (
      f"Batch size ({config.batch_size}) should be divisible by the number "
      f"of devices ({jax.local_device_count()}).")

  per_device_batch_size = config.batch_size // jax.local_device_count()
  train_ds, eval_ds, num_classes = data.get_dataset(
      config.dataset,
      per_device_batch_size,
      data_rng,
      train_preprocessing_fn=train_preprocessing_fn,
      eval_preprocessing_fn=eval_preprocessing_fn,
      **config.get("data", {}))

  module = ClassificationModule.partial(config=config, num_classes=num_classes)

  optimizer = create_optimizer(config)

  # Enable the relevant statistics aggregators.
  stats_aggregators = []

  train_metrics_dict = {
      "train_loss": classification_lib.cross_entropy,
      "train_accuracy": classification_lib.accuracy,
  }
  eval_metrics_dict = {
      "eval_loss": classification_lib.cross_entropy,
      "eval_accuracy": classification_lib.accuracy,
  }
  loss_fn = classification_lib.cross_entropy

  def loss_from_stats(field, multiplier):
    return lambda logits, labels, stats: multiplier * stats[field]

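  # Note: `loss_fn` may become a list of loss terms below; the training loop
  # is assumed to sum them, giving roughly
  # cross_entropy + entropy_regularizer * stats[stat_field].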
  # Add the entropy regularizer to the loss if needed.
  if (config.model == "patchnet" and
      config.selection_method not in [sample_patches.SelectionMethod.HARD_TOPK,
                                      sample_patches.SelectionMethod.RANDOM]):
    entropy_regularizer = config.get("entropy_regularizer", 0.)
    entropy_before_normalization = config.get("entropy_before_normalization",
                                              False)

    stat_field = "entropy"
    if entropy_before_normalization:
      stat_field = "entropy_before_normalization"

    if entropy_regularizer != 0.:
      logging.info("Adding an entropy regularizer %s normalization to the "
                   "loss, with weight %f.",
                   "before" if entropy_before_normalization else "after",
                   entropy_regularizer)
      loss_fn = [loss_fn, loss_from_stats(stat_field, entropy_regularizer)]

    def entropy_aggregator(stats):
      return {stat_field: stats[stat_field]}
    stats_aggregators.append(entropy_aggregator)

  def add_image_prefix(image_aggregator):
    def aggregator(stats):
      d = image_aggregator(stats)
      return {f"image_{k}": v for k, v in d.items()}
    return aggregator

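  # Keys prefixed with "image_" are assumed to be logged as image summaries
  # (rather than scalars) by the training loop.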
  if config.model == "patchnet" and config.get("log_images", True):
    @add_image_prefix
    def plot_patches(stats):
      keys = ["extracted_patches", "x", "scores"]
      return {k: stats[k] for k in keys if k in stats}

    stats_aggregators.append(plot_patches)

  state = classification_lib.training_loop(
      module=module,
      rng=rng,
      train_ds=train_ds,
      eval_ds=eval_ds,
      loss_fn=loss_fn,
      optimizer=optimizer,
      train_metrics_dict=train_metrics_dict,
      eval_metrics_dict=eval_metrics_dict,
      stats_aggregators=stats_aggregators,
      config=config,
      workdir=workdir)
  return state


def main(argv):
  del argv

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  logging.info("JAX host: %d / %d", jax.host_id(), jax.host_count())
  logging.info("JAX devices: %r", jax.devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Borg task 0 is not guaranteed to be host 0.)
  platform.work_unit().set_task_status(
      f"host_id: {jax.host_id()}, host_count: {jax.host_count()}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.workdir, "workdir")

  state = train_and_evaluate(FLAGS.config, FLAGS.workdir)
  del state


if __name__ == "__main__":
  flags.mark_flags_as_required(["config", "workdir"])
  app.run(main)