google-research
590 lines · 19.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""NTS-Net adapted for perturbed top-k.
17
18Based on the original PyTorch code
19https://github.com/yangze0930/NTS-Net/blob/master/core/model.py
20"""
21
22import enum23import functools24import math25from typing import List, Tuple26
27from absl import app28from absl import flags29from absl import logging30import chex31from clu import platform32import einops33from flax.deprecated import nn34import jax35import jax.numpy as jnp36import ml_collections37import ml_collections.config_flags as config_flags38from off_the_grid.lib import data39from off_the_grid.lib import models40from off_the_grid.lib import utils41import off_the_grid.lib.classification_utils as classification_lib42from off_the_grid.lib.layers import sample_patches43from off_the_grid.lib.layers import transformer44import optax45import tensorflow as tf46
47
FLAGS = flags.FLAGS

# Command-line interface: a full training config plus a working directory.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work unit directory.")
# NOTE(review): appears unused in this file — the class count actually comes
# from data.get_dataset below; presumably 200 refers to CUB-200 (confirm).
NUM_CLASSES = 200

# Anchor shapes for the three score levels produced by ProposalNet. Each
# entry yields len(scale) * len(aspect_ratio) anchor shapes — 6, 6 and 9
# respectively, 21 in total (see the trailing comments). `stride` and `size`
# are consumed by weighted_anchor_aggregator as the patch-extraction stride
# and the base patch side length before scale/aspect-ratio adjustment.
ANCHORS_SETTINGS = (
    dict(
        layer="p3",
        stride=32,
        size=48,
        scale=[2**(1. / 3.), 2**(2. / 3.)],
        aspect_ratio=[0.667, 1, 1.5]),  # Anchors 0-5
    dict(
        layer="p4",
        stride=64,
        size=96,
        scale=[2**(1. / 3.), 2**(2. / 3.)],
        aspect_ratio=[0.667, 1, 1.5]),  # Anchors 6-11
    dict(
        layer="p5",
        stride=128,
        size=192,
        scale=[1, 2**(1. / 3.), 2**(2. / 3.)],
        aspect_ratio=[0.667, 1, 1.5]),  # Anchors 12-20
)
75
76
class Communication(str, enum.Enum):
  """How ProposalNet levels exchange information before scoring."""
  NONE = "none"
  SQUEEZE_EXCITE_D = "squeeze_excite_d"
  SQUEEZE_EXCITE_X = "squeeze_excite_x"
  TRANSFORMER = "transformer"
83
def zeroone(scores, x_min, x_max):
  """Rescales every entry of `scores` into the [0, 1] range.

  Args:
    scores: Iterable of arrays (or scalars) to rescale.
    x_min: Lower bound used for the rescaling.
    x_max: Upper bound used for the rescaling (a small epsilon keeps the
      division safe when x_min == x_max).

  Returns:
    A list with each element mapped through (x - x_min) / (x_max - x_min).
  """
  span = x_max - x_min + 1e-5
  rescaled = []
  for score in scores:
    rescaled.append((score - x_min) / span)
  return rescaled
88
class ProposalNet(nn.Module):
  """FPN inspired scorer module.

  Scores anchor positions of a backbone feature map. Three convolutional
  "down" levels produce features at three resolutions; 1x1 "tidy"
  convolutions then emit 6, 6 and 9 anchor scores per position (21 score
  maps in total, matching ANCHORS_SETTINGS). Optionally the levels
  communicate first (squeeze-excite or transformer).
  """

  def apply(self, x,
            communication: Communication = Communication.NONE,
            train: bool = True):
    """Forward pass.

    Args:
      x: Backbone features, shape (batch, height, width, channels).
      communication: Cross-scale communication mode; see `Communication`.
      train: Training-mode flag; only used by the transformer variant.

    Returns:
      Tuple (normalized_scores, stats): `normalized_scores` is a list of 21
      score maps of shape (batch, h', w'), min/max-normalized per example;
      `stats` holds the normalized list and the raw flattened scores.
    """
    batch_size = x.shape[0]

    if communication is Communication.SQUEEZE_EXCITE_X:
      # Channel gating applied directly on the input features.
      x = sample_patches.SqueezeExciteLayer(x)
    # end if squeeze excite x

    # Three levels; down2 and down3 halve the spatial resolution each.
    d1 = nn.relu(nn.Conv(
        x, 128, kernel_size=(3, 3), strides=(1, 1), bias=True, name="down1"))
    d2 = nn.relu(nn.Conv(
        d1, 128, kernel_size=(3, 3), strides=(2, 2), bias=True, name="down2"))
    d3 = nn.relu(nn.Conv(
        d2, 128, kernel_size=(3, 3), strides=(2, 2), bias=True, name="down3"))

    if communication is Communication.SQUEEZE_EXCITE_D:
      # Squeeze-excite computed jointly over the positions of all levels.
      d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
      d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
      d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

      # Per-level token counts, needed to split the result back below.
      nd1 = d1_flatten.shape[1]
      nd2 = d2_flatten.shape[1]

      d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                   axis=1)

      # Squeeze: average over all positions; excite: bottleneck MLP
      # (reduction factor 4) followed by a sigmoid channel gate.
      num_channels = d_together.shape[-1]
      y = d_together.mean(axis=1)
      y = nn.Dense(y, features=num_channels // 4, bias=False)
      y = nn.relu(y)
      y = nn.Dense(y, features=num_channels, bias=False)
      y = nn.sigmoid(y)

      d_together = d_together * y[:, None, :]

      # split and reshape back into the three per-level feature maps
      d1 = d_together[:, :nd1].reshape(d1.shape)
      d2 = d_together[:, nd1:nd1+nd2].reshape(d2.shape)
      d3 = d_together[:, nd1+nd2:].reshape(d3.shape)

    elif communication is Communication.TRANSFORMER:
      # Let the positions of all levels attend to each other.
      d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
      d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
      d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

      nd1 = d1_flatten.shape[1]
      nd2 = d2_flatten.shape[1]

      d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                   axis=1)

      # Learned positional encodings, one per (level, position) token.
      positional_encodings = self.param(
          "scale_ratio_position_encodings",
          shape=(1,) + d_together.shape[1:],
          initializer=jax.nn.initializers.normal(1. / d_together.shape[-1]))
      d_together = transformer.Transformer(
          d_together + positional_encodings,
          num_layers=2,
          num_heads=8,
          is_training=train)

      # split and reshape back into the three per-level feature maps
      d1 = d_together[:, :nd1].reshape(d1.shape)
      d2 = d_together[:, nd1:nd1+nd2].reshape(d2.shape)
      d3 = d_together[:, nd1+nd2:].reshape(d3.shape)

    # 1x1 convolutions mapping features to per-anchor scores:
    # 6 anchors on the first two levels, 9 on the third.
    t1 = nn.Conv(
        d1, 6, kernel_size=(1, 1), strides=(1, 1), bias=True, name="tidy1")
    t2 = nn.Conv(
        d2, 6, kernel_size=(1, 1), strides=(1, 1), bias=True, name="tidy2")
    t3 = nn.Conv(
        d3, 9, kernel_size=(1, 1), strides=(1, 1), bias=True, name="tidy3")

    # One (batch, h', w', 1) score map per anchor shape (21 in total).
    raw_scores = (jnp.split(t1, 6, axis=-1) +
                  jnp.split(t2, 6, axis=-1) +
                  jnp.split(t3, 9, axis=-1))

    # The following is for normalization: per-example min and max taken
    # over every score of every level.
    t = jnp.concatenate((jnp.reshape(t1, [batch_size, -1]),
                         jnp.reshape(t2, [batch_size, -1]),
                         jnp.reshape(t3, [batch_size, -1])), axis=1)
    t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
    t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
    normalized_scores = zeroone(raw_scores, t_min, t_max)

    stats = {
        "scores": normalized_scores,
        "raw_scores": t,
    }
    # removes the split dimension. scores are now b x h' x w' shaped
    normalized_scores = [s.squeeze(-1) for s in normalized_scores]

    return normalized_scores, stats
186
def extract_weighted_patches(x,
                             weights,
                             kernel,
                             stride,
                             padding):
  """Accumulates weighted image patches with jax.lax.scan.

  Slides a `kernel`-sized window over the (padded) input with the given
  `stride`; every window is scaled by its per-position weight and summed
  into k accumulators.

  Args:
    x: Input of shape (batch, height, width, channels).
    weights: Per-position weights of shape (batch, k, h', w').
    kernel: (kernel_h, kernel_w) patch size.
    stride: (stride_h, stride_w) window step.
    padding: (pad_h, pad_w) leading padding; a kernel-sized amount of
      trailing padding is always added so every window stays in bounds.

  Returns:
    Tensor of shape (batch, k, kernel_h, kernel_w, channels).
  """
  logging.info("recompiling for kernel=%s and stride=%s and padding=%s", kernel,
               stride, padding)
  padded = jnp.pad(x, ((0, 0),
                       (padding[0], padding[0] + kernel[0]),
                       (padding[1], padding[1] + kernel[1]),
                       (0, 0)))
  batch, _, _, num_channels = padded.shape
  _, num_maps, rows, cols = weights.shape

  def add_window(carry, row_col):
    # One scan step: slice the window at (row, col) and add its weighted
    # contribution to every one of the k accumulators.
    row, col = row_col
    window = jax.lax.dynamic_slice(
        padded,
        (0, row * stride[0], col * stride[1], 0),
        (batch, kernel[0], kernel[1], num_channels))
    per_map_weight = weights[:, :, row, col]
    contribution = jnp.einsum("bk, bijc -> bkijc", per_map_weight, window)
    return carry + contribution, None

  # All (row, col) positions of the weight grid, in row-major order.
  grid = jnp.stack(
      jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij"),
      axis=-1).reshape((-1, 2))

  accumulator = jnp.zeros((batch, num_maps, kernel[0], kernel[1],
                           num_channels))
  result, _ = jax.lax.scan(add_window, accumulator, grid)
  return result
224
def weighted_anchor_aggregator(x, weights):
  """Sums weighted anchor crops of `x`, resized to 224x224, per weight map.

  For each of the 21 anchor shapes defined in ANCHORS_SETTINGS, a weighted
  combination of patches is extracted from `x`, resized to 224x224 with
  bilinear interpolation, and the contributions of all anchor shapes are
  summed.

  Args:
    x: Input images, shape (batch, height, width, 3).
    weights: Sequence of 21 weight tensors, one per anchor shape, each of
      shape (batch, k, h', w') matching that anchor's score-map grid.

  Returns:
    A (batch * k, 224, 224, 3) tensor of aggregated crops.
  """
  anchor_idx = 0
  resized_crops = []

  for anchor_info in ANCHORS_SETTINGS:
    stride = anchor_info["stride"]
    base_size = anchor_info["size"]
    for scale in anchor_info["scale"]:
      for aspect_ratio in anchor_info["aspect_ratio"]:
        # Kernel keeps the area ~ (base_size * scale)^2 while skewing
        # height/width by the aspect ratio.
        sqrt_ar = float(aspect_ratio) ** 0.5
        kernel_size = (int(base_size * scale / sqrt_ar),
                       int(base_size * scale * sqrt_ar))
        # Padding centers the kernel on the anchor position.
        padding = (math.ceil((kernel_size[0] - stride) / 2.),
                   math.ceil((kernel_size[1] - stride) / 2.))
        patches = extract_weighted_patches(
            x, weights[anchor_idx], kernel_size, (stride, stride), padding)
        patches = jnp.reshape(patches,
                              [-1, kernel_size[0], kernel_size[1], 3])
        resized_crops.append(
            jax.image.resize(patches, [patches.shape[0], 224, 224, 3],
                             "bilinear"))
        anchor_idx += 1

  return jnp.sum(jnp.stack(resized_crops, axis=0), axis=0)
253
class AttentionNet(nn.Module):
  """The complete NTS-Net model using perturbed top-k.

  Pipeline: a shared ResNet50 encodes the full image, ProposalNet scores
  anchor positions, a differentiable (perturbed) top-k selects k anchors,
  the corresponding soft crops are re-encoded by the same ResNet50, and the
  part and whole-image features are combined for classification.
  """

  def apply(self,
            x,
            config,
            num_classes,
            train = True):
    """Creates a model definition.

    Args:
      x: Input images, shape (batch, height, width, channels).
      config: Model config; reads `k`, `ptopk_sigma`, `ptopk_num_samples`
        and `communication`.
      num_classes: Number of output classes.
      train: Whether dropout (and the learned modules) run in training mode.

    Returns:
      Tuple (all_logits, stats): `all_logits` has "raw_logits" (whole
      image), "concat_logits" (whole image + mean part features) and
      "part_logits" ([batch, k, num_classes]); `stats` carries intermediate
      values for logging plus the "rpn_scores_entropy" regularizer term.
    """
    b, c = x.shape[0], x.shape[3]
    k = config.k
    sigma = config.ptopk_sigma
    num_samples = config.ptopk_num_samples

    # Learnable multiplier on the perturbation noise scale. NOTE(review):
    # the state name "sigma_mutiplier" is misspelled; renaming it would
    # invalidate existing checkpoints, so it is kept as-is.
    sigma *= self.state("sigma_mutiplier", shape=(),
                        initializer=nn.initializers.ones).value

    stats = {"x": x, "sigma": sigma}

    # The same ResNet50 instance encodes the full image and the part crops.
    feature_extractor = models.ResNet50.shared(train=train, name="ResNet_0")

    rpn_feature = feature_extractor(x)
    # stop_gradient keeps the scorer from back-propagating into the
    # backbone features.
    rpn_scores, rpn_stats = ProposalNet(
        jax.lax.stop_gradient(rpn_feature),
        communication=Communication(config.communication),
        train=train)
    stats.update(rpn_stats)

    # rpn_scores are a list of score images. We keep track of the structure
    # because it is used in the aggregation step later-on.
    rpn_scores_shapes = [s.shape for s in rpn_scores]
    rpn_scores_flat = jnp.concatenate(
        [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
    # Differentiable top-k over all flattened anchor scores.
    top_k_indicators = sample_patches.select_patches_perturbed_topk(
        rpn_scores_flat,
        k=k,
        sigma=sigma,
        num_samples=num_samples)
    top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
    # Un-flatten the k indicator rows back into per-score-map weight images.
    offset = 0
    weights = []
    for sh in rpn_scores_shapes:
      cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
      cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
      weights.append(cur)
      offset += sh[1] * sh[2]
    chex.assert_equal(offset, top_k_indicators.shape[-1])

    # Soft-extract the k selected anchors as 224x224 part images.
    part_imgs = weighted_anchor_aggregator(x, weights)
    chex.assert_shape(part_imgs, (b * k, 224, 224, c))
    stats["part_imgs"] = jnp.reshape(part_imgs, [b, k*224, 224, c])

    part_features = feature_extractor(part_imgs)
    part_features = jnp.mean(part_features, axis=[1, 2])  # GAP the spatial dims

    part_features = nn.dropout(  # features from parts
        jnp.reshape(part_features, [b * k, 2048]),
        0.5,
        deterministic=not train,
        rng=nn.make_rng())
    features = nn.dropout(  # features from whole image
        jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
        0.5,
        deterministic=not train,
        rng=nn.make_rng())

    # Mean pool all part features, add it to features and predict logits.
    concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                          axis=1) + features
    concat_logits = nn.Dense(concat_out, num_classes)
    raw_logits = nn.Dense(features, num_classes)
    part_logits = jnp.reshape(nn.Dense(part_features, num_classes), [b, k, -1])

    all_logits = {
        "raw_logits": raw_logits,
        "concat_logits": concat_logits,
        "part_logits": part_logits,
    }
    # add entropy into it for entropy regularization.
    stats["rpn_scores_entropy"] = jax.scipy.special.entr(
        jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
    return all_logits, stats
337
def create_optimizer(config):
  """Builds the optax optimizer chain described by `config`.

  The chain is assembled in this order: gradient clipping, coupled weight
  decay, the optimizer with its learning rate schedule, decoupled weight
  decay, and parameter freezing.

  Args:
    config: ml_collections config providing at least `optimizer`,
      `learning_rate` and `num_train_steps`; see the body for the
      optional fields.

  Returns:
    An optax gradient transformation chaining all configured updates.
  """
  transforms = []

  # Gradient clipping: `gradient_norm_clip` (by global norm) and
  # `gradient_value_clip` (by absolute value) are mutually exclusive.
  if "gradient_clip" in config:
    raise ValueError("'gradient_clip' is deprecated, please use "
                     "'gradient_norm_clip'.")
  has_norm_clip = "gradient_norm_clip" in config
  has_value_clip = "gradient_value_clip" in config
  assert not (has_norm_clip and has_value_clip), (
      "Gradient clipping by norm and by value are exclusive.")
  if has_norm_clip:
    transforms.append(optax.clip_by_global_norm(config.gradient_norm_clip))
  if has_value_clip:
    transforms.append(optax.clip(config.gradient_value_clip))

  # Learning rate schedule with optional warmup / step decay / cosine decay.
  schedule_fn = utils.get_optax_schedule_fn(
      warmup_ratio=config.get("warmup_ratio", 0.),
      num_train_steps=config.num_train_steps,
      decay=config.get("learning_rate_step_decay", 1.0),
      decay_at_steps=config.get("learning_rate_decay_at_steps", []),
      cosine_decay_schedule=config.get("cosine_decay", False))
  schedule_parts = [optax.scale_by_schedule(schedule_fn)]

  # Scale some parameters matching a regex by a multiplier. Config field
  # `scaling_learning_rate_by_regex` is a list of (regex, multiplier) pairs.
  for regex, multiplier in config.get("scaling_learning_rate_by_regex", []):
    logging.info("Learning rate is scaled by %f for parameters matching '%s'",
                 multiplier, regex)
    schedule_parts.append(utils.scale_selected_parameters(regex, multiplier))
  schedule_optimizer = optax.chain(*schedule_parts)

  if "weight_decay_coupled" in config and config.weight_decay_coupled > 0.:
    # Applied before the optimizer transformation, which makes the decay
    # coupled; the step size here is a constant 1, not the schedule.
    transforms.append(utils.decoupled_weight_decay(
        decay=config.weight_decay_coupled,
        step_size_fn=lambda x: jnp.ones([], dtype=jnp.float32)))

  optimizer_name = config.optimizer.lower()
  if optimizer_name == "adam":
    # For Adam the schedule transformation runs after the optimizer update.
    transforms.append(optax.adam(config.learning_rate))
    transforms.append(schedule_optimizer)
  elif optimizer_name == "sgd":
    # For SGD the schedule is applied before the momentum update.
    transforms.append(schedule_optimizer)
    transforms.append(
        optax.sgd(config.learning_rate, momentum=config.momentum))
  else:
    raise NotImplementedError("Invalid optimizer: {}".format(
        config.optimizer))

  if "weight_decay" in config and config.weight_decay > 0.:
    # Decoupled weight decay, following the learning rate schedule.
    transforms.append(utils.decoupled_weight_decay(
        decay=config.weight_decay, step_size_fn=schedule_fn))

  # Freeze parameters that match the given regexes (if any).
  freeze_regexes = config.get("freeze_weights_regex", []) or []
  if isinstance(freeze_regexes, str):
    freeze_regexes = [freeze_regexes]
  for regex in freeze_regexes:
    transforms.append(utils.freeze(regex))

  return optax.chain(*transforms)
407
def cross_entropy(logits, labels):
  """Mean cross-entropy between integer labels and unnormalized logits.

  Args:
    logits: Unnormalized scores, shape [batch, num_classes].
    labels: Integer class labels, shape [batch].

  Returns:
    Scalar mean negative log-likelihood of the labels.
  """
  log_probs = jax.nn.log_softmax(logits)
  label_log_probs = jnp.take_along_axis(log_probs, labels[:, None], axis=1)
  return -jnp.mean(label_log_probs)
414
def ntsnet_loss(logits_dict, labels, stats, config):
  """Customized cross entropy loss for dictionary of logits.

  Combines three cross-entropy terms (whole-image logits, concatenated
  logits, and per-part logits) with an entropy regularizer on the RPN
  scores.

  Args:
    logits_dict: Dict with "raw_logits", "concat_logits" and "part_logits"
      (the latter of shape [batch, k, num_classes]).
    labels: Integer class labels of shape [batch].
    stats: Model statistics dict; must contain "rpn_scores_entropy".
    config: Config providing `entropy_regularizer`.

  Returns:
    Scalar total loss.
  """
  raw_loss = cross_entropy(logits_dict["raw_logits"], labels)
  concat_loss = cross_entropy(logits_dict["concat_logits"], labels)

  # Reuse the shared metric helper instead of duplicating the per-part loss
  # computation (it tiles the labels over the k parts and flattens before
  # the cross entropy) — keeps the train/eval metric and the loss in sync.
  part_loss = cross_entropy_part_logits(logits_dict, labels, stats)

  reg = config.entropy_regularizer * rpn_scores_entropy(
      logits_dict, labels, stats)

  return raw_loss + concat_loss + part_loss + reg
436
def accuracy(logits_dict, labels, stats):
  """Top-1 accuracy computed on the "concat_logits" predictions."""
  del stats  # Unused; kept for the shared metric-fn signature.
  predicted = jnp.argmax(logits_dict["concat_logits"], axis=-1)
  return jnp.mean(predicted == labels)
444
def cross_entropy_raw_logits(logits_dict, labels, stats):
  """Cross entropy restricted to the whole-image ("raw") logits."""
  del stats  # Unused; kept for the shared metric-fn signature.
  raw_logits = logits_dict["raw_logits"]
  return cross_entropy(raw_logits, labels)
450
def cross_entropy_concat_logits(logits_dict, labels, stats):
  """Cross entropy restricted to the combined ("concat") logits."""
  del stats  # Unused; kept for the shared metric-fn signature.
  concat_logits = logits_dict["concat_logits"]
  return cross_entropy(concat_logits, labels)
456
def cross_entropy_part_logits(logits_dict, labels, stats):
  """Cross entropy over the k per-part logits, averaged over all parts."""
  del stats  # Unused; kept for the shared metric-fn signature.
  part_logits = logits_dict["part_logits"]  # [batch, k, num_classes]
  num_parts = part_logits.shape[1]
  num_classes = part_logits.shape[2]
  # Every part of an image shares that image's label.
  flat_logits = jnp.reshape(part_logits, [-1, num_classes])
  flat_labels = jnp.reshape(
      jnp.tile(jnp.expand_dims(labels, axis=1), [1, num_parts]), [-1])
  return cross_entropy(flat_logits, flat_labels)
469
def rpn_scores_entropy(logits_dict, labels, stats):
  """Returns the RPN score entropy already computed by the model."""
  del logits_dict, labels  # Unused; kept for the shared metric-fn signature.
  return stats["rpn_scores_entropy"]
476
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.

  Returns:
    Training state.
  """
  rng = jax.random.PRNGKey(config.seed)
  rng, data_rng = jax.random.split(rng)

  # Make sure config defines num_epochs and num_train_steps appropriately.
  utils.check_epochs_and_steps(config)

  # Check that perturbed-topk is selection method.
  assert config.selection_method == "perturbed-topk", (
      "ntsnet only supports perturbed-topk as selection method. Got: {}".format(
          config.selection_method))

  train_preprocessing_fn, eval_preprocessing_fn = data.parse_preprocessing_strings(
      config.get("train_preprocess_str", ""),
      config.get("eval_preprocess_str", ""))

  assert config.batch_size % jax.local_device_count() == 0, (
      f"Batch size ({config.batch_size}) should be divisible by number of "
      f"devices ({jax.local_device_count()}).")

  per_device_batch_size = config.batch_size // jax.local_device_count()
  train_ds, eval_ds, num_classes = data.get_dataset(
      config.dataset,
      per_device_batch_size,
      data_rng,
      train_preprocessing_fn=train_preprocessing_fn,
      eval_preprocessing_fn=eval_preprocessing_fn,
      **config.get("data", {}))

  # Bind the static arguments so the training loop sees only the inputs.
  module = AttentionNet.partial(config=config, num_classes=num_classes)

  optimizer = create_optimizer(config)

  # Loss and metrics all share the (logits_dict, labels, stats) signature.
  loss_fn = functools.partial(ntsnet_loss, config=config)
  train_metrics_dict = {
      "train_loss": loss_fn,
      "train_loss_raw": cross_entropy_raw_logits,
      "train_loss_concat": cross_entropy_concat_logits,
      "train_loss_part": cross_entropy_part_logits,
      "train_accuracy": accuracy,
      "train_rpn_scores_entropy": rpn_scores_entropy,
  }
  eval_metrics_dict = {
      "eval_loss": loss_fn,
      "eval_loss_raw": cross_entropy_raw_logits,
      "eval_loss_concat": cross_entropy_concat_logits,
      "eval_loss_part": cross_entropy_part_logits,
      "eval_accuracy": accuracy,
      "eval_rpn_scores_entropy": rpn_scores_entropy,
  }

  # Enables relevant statistics aggregator.
  stats_aggregators = []

  def add_image_prefix(image_aggregator):
    # Prefixes every aggregated entry with "image_".
    def aggregator(stats):
      d = image_aggregator(stats)
      return {f"image_{k}": v for k, v in d.items()}
    return aggregator

  if config.get("log_images", True):
    @add_image_prefix
    def plot_patches(stats):
      # NOTE(review): the +1 / 2 rescaling assumes images are in [-1, 1];
      # confirm against the preprocessing pipeline.
      d = {
          "part_imgs": (stats["part_imgs"] + 1.0) / 2.0,
          "x": (stats["x"] + 1.0) / 2.0
      }
      for i, sc in enumerate(stats["scores"]):
        d[f"scores_{i}"] = sc
      return d

    stats_aggregators.append(plot_patches)

  stats_aggregators.append(lambda x: {"sigma": x["sigma"]})

  state = classification_lib.training_loop(
      module=module,
      rng=rng,
      train_ds=train_ds,
      eval_ds=eval_ds,
      loss_fn=loss_fn,
      optimizer=optimizer,
      train_metrics_dict=train_metrics_dict,
      eval_metrics_dict=eval_metrics_dict,
      stats_aggregators=stats_aggregators,
      config=config,
      workdir=workdir)
  return state
576
def main(argv):
  """Entry point: trains and evaluates with the flag-provided config."""
  del argv  # Unused.

  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
  # make it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], "GPU")

  train_and_evaluate(FLAGS.config, FLAGS.workdir)
587
if __name__ == "__main__":
  # Both flags must be supplied on the command line before the app runs.
  flags.mark_flags_as_required(["config", "workdir"])
  app.run(main)