google-research
375 lines · 11.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# pylint: disable=logging-format-interpolation
17# pylint: disable=unused-import
18# pylint: disable=invalid-unary-operand-type
19# pylint: disable=g-long-lambda
20# pylint: disable=g-direct-tensorflow-import
21
22r"""Custom ops."""
23
24import os25import sys26
27from absl import logging28import numpy as np29import tensorflow.compat.v1 as tf30
31from differentiable_data_selection import common_utils32
# Module-level flag: when True, the variable/op helpers below cast their
# outputs to tf.bfloat16. Toggled via `use_bfloat16()`.
USE_BFLOAT16 = False
def floatx():
  """Return the active compute dtype: bfloat16 when enabled, else float32."""
  if USE_BFLOAT16:
    return tf.bfloat16
  return tf.float32
def use_bfloat16():
  """Switch the module into bfloat16 mode (affects helpers that read the flag)."""
  global USE_BFLOAT16
  USE_BFLOAT16 = True
def get_variable(name, shape, initializer, trainable=True,
                 convert_if_using_bfloat16=True):
  """Create a resource variable, optionally cast to `tf.bfloat16`.

  The cast happens only when bfloat16 mode is on AND
  `convert_if_using_bfloat16` is True.
  """
  var = tf.get_variable(
      name=name,
      shape=shape,
      initializer=initializer,
      trainable=trainable,
      use_resource=True)
  if convert_if_using_bfloat16 and USE_BFLOAT16:
    return tf.cast(var, tf.bfloat16)
  return var
def log_tensor(x, training):
  """Log a tensor's name, device and shape (training graphs only)."""
  if not training:
    return
  logging.info(f'{x.name:<90} {x.device} {x.shape}')
def _conv_kernel_initializer(shape, dtype=None, partition_info=None):
  """Normal initializer for conv kernels with stddev sqrt(2 / fan_out)."""
  del partition_info
  k_h, k_w, _, filters_out = shape[-4:]
  fan_out = int(k_h * k_w * filters_out)
  stddev = np.sqrt(2.0 / fan_out)
  return tf.random.normal(shape, mean=0.0, stddev=stddev, dtype=dtype)
def _dense_kernel_initializer(shape, dtype=None, partition_info=None):
  """Uniform initializer in [-1/sqrt(n), 1/sqrt(n)], n = last dim of `shape`."""
  del partition_info
  bound = 1. / np.sqrt(shape[-1])
  return tf.random_uniform(shape, -bound, bound, dtype=dtype)
def conv2d(x, filter_size, num_out_filters, stride=1,
           use_bias=False, padding='SAME', data_format='NHWC', name='conv2d',
           w=None, b=None):
  """2-D convolution with optional bias.

  Args:
    x: input tensor; the input channel count is read from its last dim.
    filter_size: spatial size of the (square) kernel.
    num_out_filters: number of output channels.
    stride: spatial stride.
    use_bias: whether to add a bias term.
    padding: padding scheme forwarded to `tf.nn.conv2d`.
    data_format: data format forwarded to `tf.nn.conv2d`.
    name: variable scope name.
    w: optional pre-built kernel; created here when None.
    b: optional pre-built bias; created here when None.

  Returns:
    The convolved (and optionally biased) tensor.
  """
  with tf.variable_scope(name):
    num_inp_filters = x.shape[-1].value

    # Bug fix: previously a kernel variable was ALWAYS created, silently
    # ignoring a caller-supplied `w`. Honor `w`, mirroring the `b` handling.
    if w is None:
      w = tf.get_variable(
          name='kernel',
          shape=[filter_size, filter_size, num_inp_filters, num_out_filters],
          initializer=_conv_kernel_initializer,
          trainable=True,
          use_resource=True)

    if USE_BFLOAT16:
      w = tf.cast(w, tf.bfloat16)
    x = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1],
                     padding=padding, data_format=data_format)

    if use_bias:
      if b is None:
        b = tf.get_variable(
            name='bias',
            shape=[num_out_filters],
            initializer=tf.initializers.zeros(),
            trainable=True,
            use_resource=True)
      if USE_BFLOAT16:
        b = tf.cast(b, tf.bfloat16)
      x = tf.nn.bias_add(x, b, name='bias_add')
    return x
112
def dense(x, num_outputs, use_bias=True, name='dense'):
  """Custom fully connected layer."""
  fan_in = x.shape[-1].value
  with tf.variable_scope(name):
    kernel = tf.get_variable(
        name='kernel',
        shape=[fan_in, num_outputs],
        initializer=_dense_kernel_initializer,
        trainable=True,
        use_resource=True)
    if USE_BFLOAT16:
      kernel = tf.cast(kernel, tf.bfloat16)

    x = tf.linalg.matmul(x, kernel)
    if use_bias:
      bias = tf.get_variable(
          name='bias',
          shape=[num_outputs],
          initializer=tf.initializers.zeros(),
          trainable=True,
          use_resource=True)
      if USE_BFLOAT16:
        bias = tf.cast(bias, tf.bfloat16)
      x = tf.nn.bias_add(x, bias, name='bias_add')
    return x
139
def avg_pool(x, filter_size, stride, padding='SAME', name='avg_pool'):
  """Average pooling with a square window."""
  return tf.nn.avg_pool(
      x,
      ksize=[filter_size, filter_size],
      strides=[1, stride, stride, 1],
      padding=padding,
      name=name)
def max_pool(x, filter_size, stride, padding='SAME', name='max_pool'):
  """Max pool."""
  # Docstring fixed: previously said "Avg pool." (copy-paste from avg_pool).
  x = tf.nn.max_pool(
      x,
      ksize=[filter_size, filter_size],
      strides=[1, stride, stride, 1],
      padding=padding,
      name=name)
  return x
def relu(x, leaky=0.2, name='relu'):
  """Leaky ReLU; `leaky` is the negative-slope coefficient (0. = plain ReLU)."""
  return tf.nn.leaky_relu(x, name=name, alpha=leaky)
def batch_norm(x, params, training, name='batch_norm', **kwargs):
  """Batch-norm entry point; delegates to `sync_batch_norm`."""
  return sync_batch_norm(x, params, training, name=name, **kwargs)
def sync_batch_norm(x, params, training, name='batch_norm'):
  """Sync batch_norm.

  Normalizes `x` over axes [0, 1, 2] (per-channel; presumably NHWC --
  TODO confirm against callers). During TPU training the batch statistics
  are averaged across (groups of) replicas via `tf.tpu.cross_replica_sum`.

  Args:
    x: input tensor; cast to float32 for the statistics.
    params: config object; reads `use_tpu`, `num_replicas`,
      `batch_norm_epsilon` and `batch_norm_decay`.
    training: Python bool selecting batch vs. moving statistics.
    name: variable scope name.

  Returns:
    The normalized tensor; or `(x, mean, variance)` in the training branch
    if the moving statistics were not `tf.Variable`s (not expected here,
    since they are created with `tf.get_variable` just above).
  """
  size = x.shape[-1].value

  with tf.variable_scope(name):
    # Learnable scale/offset plus non-trainable moving statistics,
    # one value per channel.
    gamma = tf.get_variable(name='gamma',
                            shape=[size],
                            initializer=tf.initializers.ones(),
                            trainable=True)
    beta = tf.get_variable(name='beta',
                           shape=[size],
                           initializer=tf.initializers.zeros(),
                           trainable=True)
    moving_mean = tf.get_variable(name='moving_mean',
                                  shape=[size],
                                  initializer=tf.initializers.zeros(),
                                  trainable=False)
    moving_variance = tf.get_variable(name='moving_variance',
                                      shape=[size],
                                      initializer=tf.initializers.ones(),
                                      trainable=False)

    # Statistics are always computed in float32, even in bfloat16 mode.
    x = tf.cast(x, tf.float32)
    if training:
      if params.use_tpu:
        # Pick the cross-replica reduction group: all replicas when <= 8,
        # otherwise roughly 1/8 of them rounded to a power of two.
        num_replicas = params.num_replicas
        if num_replicas <= 8:
          group_assign = None
          group_shards = tf.cast(num_replicas, tf.float32)
        else:
          group_shards = max(8, num_replicas // 8)

          # round to nearest power of 2
          log_num_replicas = max(1, int(np.log(group_shards) / np.log(2.)))
          group_shards = int(np.power(2., log_num_replicas))

          # Partition replica ids into contiguous groups of `group_shards`.
          group_assign = np.arange(num_replicas, dtype=np.int32)
          group_assign = group_assign.reshape([-1, group_shards])
          group_assign = group_assign.tolist()
          group_shards = tf.cast(group_shards, tf.float32)

        # Group mean: per-replica mean divided by group size, then summed
        # across the group.
        mean = tf.reduce_mean(x, [0, 1, 2])
        mean = tf.tpu.cross_replica_sum(mean / group_shards, group_assign)

        # Var[x] = E[x^2] - E[x]^2
        mean_sq = tf.reduce_mean(tf.math.square(x), [0, 1, 2])
        mean_sq = tf.tpu.cross_replica_sum(mean_sq / group_shards, group_assign)
        variance = mean_sq - tf.math.square(mean)
      else:
        mean, variance = tf.nn.moments(x, [0, 1, 2])

      x = tf.nn.batch_normalization(
          x, mean=mean, variance=variance, offset=beta, scale=gamma,
          variance_epsilon=params.batch_norm_epsilon)

      if USE_BFLOAT16:
        x = tf.cast(x, tf.bfloat16, name='batch_norm_recast')

      if (isinstance(moving_mean, tf.Variable) and
          isinstance(moving_variance, tf.Variable)):
        decay = tf.cast(1. - params.batch_norm_decay, tf.float32)
        def u(moving, normal, name):
          # EMA update op: moving -= (1 - decay) * (moving - batch_stat).
          # On TPU the batch statistic is first averaged over ALL replicas.
          if params.use_tpu:
            num_replicas_fp = tf.cast(params.num_replicas, tf.float32)
            normal = tf.tpu.cross_replica_sum(normal) / num_replicas_fp
          diff = decay * (moving - normal)
          return tf.assign_sub(moving, diff, use_locking=True, name=name)
        # Register the EMA updates so they run with the usual UPDATE_OPS
        # dependency at train time.
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                             u(moving_mean, mean, name='moving_mean'))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                             u(moving_variance, variance,
                               name='moving_variance'))
        return x
      else:
        return x, mean, variance
    else:
      # Inference: normalize with the stored moving statistics.
      if params.use_tpu:
        x = tf.nn.batch_normalization(
            x, mean=moving_mean, variance=moving_variance, offset=beta,
            scale=gamma, variance_epsilon=params.batch_norm_epsilon)
      else:
        x, _, _ = tf.nn.fused_batch_norm(
            x, scale=gamma, offset=beta, mean=moving_mean,
            variance=moving_variance, epsilon=params.batch_norm_epsilon,
            is_training=False)

      if USE_BFLOAT16:
        x = tf.cast(x, tf.bfloat16)
      return x
261
def gpu_batch_norm(x, params, training=True, name='batch_norm'):
  """Per-replica (async) batch norm via `tf.nn.fused_batch_norm`."""
  depth = [x.shape[-1].value]
  with tf.variable_scope(name):
    gamma = get_variable('gamma', depth, tf.initializers.ones())
    beta = get_variable('beta', depth, tf.initializers.zeros())
    moving_mean = tf.get_variable(
        name='moving_mean',
        shape=depth,
        initializer=tf.initializers.zeros(),
        trainable=False,
        use_resource=True)
    moving_variance = tf.get_variable(
        name='moving_variance',
        shape=depth,
        initializer=tf.initializers.ones(),
        trainable=False,
        use_resource=True)

    if not training:
      # Inference: normalize with the stored moving statistics.
      x, _, _ = tf.nn.fused_batch_norm(
          x,
          scale=gamma,
          offset=beta,
          mean=moving_mean,
          variance=moving_variance,
          epsilon=params.batch_norm_epsilon,
          is_training=False)
      return x

    x, batch_mean, batch_var = tf.nn.fused_batch_norm(
        x,
        scale=gamma,
        offset=beta,
        epsilon=params.batch_norm_epsilon,
        is_training=True)

    def _ema_update(moving, batch_stat):
      # EMA update op: moving -= (1 - decay) * (moving - batch_stat).
      rate = tf.cast(1. - params.batch_norm_decay, tf.float32)
      return moving.assign_sub(rate * (moving - batch_stat))

    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                         _ema_update(moving_mean, batch_mean))
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                         _ema_update(moving_variance, batch_var))
    return x
308
def wrn_block(x, params, num_out_filters, stride, training=True,
              name='wrn_block'):
  """WideResNet block: two pre-activation 3x3 convs plus a skip branch."""
  with tf.variable_scope(name):
    num_inp_filters = x.shape[-1].value
    # The skip path needs a 1x1 projection whenever shape changes.
    needs_projection = stride == 2 or num_inp_filters != num_out_filters
    skip = x
    with tf.variable_scope('conv_3x3_1'):
      x = batch_norm(x, params, training)
      if needs_projection:
        # Pre-activation variant: the projected skip branches off after BN.
        skip = x
      x = relu(x)
      x = conv2d(x, 3, num_out_filters, stride)

    with tf.variable_scope('conv_3x3_2'):
      x = batch_norm(x, params, training)
      x = relu(x)
      x = conv2d(x, 3, num_out_filters, 1)

    with tf.variable_scope('residual'):
      if needs_projection:
        skip = relu(skip)
        skip = conv2d(skip, 1, num_out_filters, stride)

    x = x + skip
    log_tensor(x, True)
    return x
337
def resnet_block(x, params, num_out_filters, stride, training=True,
                 bottleneck_rate=4, name='resnet_block'):
  """ResNet bottleneck block: 1x1 reduce, 3x3, 1x1 expand, plus skip."""
  bottleneck_filters = num_out_filters // bottleneck_rate
  with tf.variable_scope(name):
    skip = x
    num_inp_filters = skip.shape[-1].value
    with tf.variable_scope('conv_1x1_1'):
      x = conv2d(x, 1, bottleneck_filters, 1)
      x = batch_norm(x, params, training)
      x = relu(x, leaky=0.)

    with tf.variable_scope('conv_3x3'):
      x = conv2d(x, 3, bottleneck_filters, stride)
      x = batch_norm(x, params, training)
      x = relu(x, leaky=0.)

    with tf.variable_scope('conv_1x1_2'):
      x = conv2d(x, 1, num_out_filters, 1)
      x = batch_norm(x, params, training)

    with tf.variable_scope('residual'):
      # Project the skip path whenever the output shape changes.
      if stride == 2 or num_inp_filters != num_out_filters:
        skip = conv2d(skip, 1, num_out_filters, stride)
        skip = batch_norm(skip, params, training)

    x = relu(x + skip, leaky=0.)
    log_tensor(x, True)
    return x
369
def dropout(x, drop_rate, training):
  """Apply dropout at training time; identity at inference."""
  return tf.nn.dropout(x, rate=drop_rate) if training else x