# Repository: google-research
# File listing metadata: 87 lines, 2.3 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16import numpy as np17import tensorflow as tf18# pylint: skip-file
19
##################### sample method #####################
21
def _sample_line(real, fake):
  """Draw a random point on the segment from `real` to `fake` per example.

  One interpolation coefficient is drawn uniformly between 0 and 1 for each
  entry along the first axis of `real` (presumably the batch axis) and is
  broadcast over every remaining axis.

  Args:
    real: Tensor whose rank is statically known (`real.shape.ndims` is used).
    fake: Tensor combined element-wise with `real`; presumably of the same
      shape.

  Returns:
    The tensor `real + alpha * (fake - real)`, with its static shape set to
    that of `real`.
  """
  leading_dim = tf.shape(real)[0]
  trailing_ones = [1] * (real.shape.ndims - 1)
  # Shape [N, 1, ..., 1]: a single coefficient per first-axis entry.
  alpha = tf.random.uniform(
      shape=[leading_dim] + trailing_ones, minval=0, maxval=1)
  interpolated = real + alpha * (fake - real)
  # Pin the static shape of the result to the one known for `real`.
  interpolated.set_shape(real.shape)
  return interpolated
29
def _sample_DRAGAN(real, fake):  # fake is useless
  """Sample points between `real` and a randomly perturbed copy of `real`.

  The incoming `fake` argument is never read; the parameter only keeps the
  signature identical to the other sampling functions.

  Args:
    real: Tensor of real examples.
    fake: Ignored.

  Returns:
    A tensor produced by `_sample_line(real, perturbed)`, where `perturbed`
    is `real` plus uniform noise scaled by half the standard deviation of
    `real` (taken over all elements).
  """
  noise = tf.random.uniform(shape=tf.shape(real), minval=0, maxval=1)
  perturbed = real + 0.5 * tf.math.reduce_std(real) * noise
  return _sample_line(real, perturbed)
36
#################### gradient penalty ####################
38
def _norm(x):
  """Return the Euclidean norm of each first-axis slice of `x`.

  Args:
    x: Tensor; every axis after the first is flattened into one.

  Returns:
    A 1-D tensor holding one norm per entry along the first axis of `x`.
  """
  flattened = tf.reshape(x, [tf.shape(x)[0], -1])
  return tf.norm(flattened, axis=1)
43
def _one_mean_gp(grad):
  """Penalty that pulls per-example gradient norms towards 1.

  Args:
    grad: Gradient tensor; norms are taken per first-axis entry via `_norm`.

  Returns:
    Scalar tensor: the mean of `(norm - 1) ** 2`.
  """
  deviation = _norm(grad) - 1
  return tf.reduce_mean(deviation**2)
49
def _zero_mean_gp(grad):
  """Penalty that pulls per-example gradient norms towards 0.

  Args:
    grad: Gradient tensor; norms are taken per first-axis entry via `_norm`.

  Returns:
    Scalar tensor: the mean of the squared norms.
  """
  squared_norm = _norm(grad)**2
  return tf.reduce_mean(squared_norm)
55
def _lipschitz_penalty(grad):
  """One-sided penalty applied only where the gradient norm exceeds 1.

  Args:
    grad: Gradient tensor; norms are taken per first-axis entry via `_norm`.

  Returns:
    Scalar tensor: the mean of `max(norm - 1, 0) ** 2`.
  """
  # Norms at or below 1 contribute nothing to the penalty.
  excess = tf.maximum(_norm(grad) - 1, 0)
  return tf.reduce_mean(excess**2)
61
def gradient_penalty(f, real, fake, gp_mode, sample_mode):
  """Compute a gradient penalty for `f` at points sampled from the inputs.

  Args:
    f: Callable applied to the sampled tensor; its output is differentiated
      with respect to that tensor.
    real: Tensor of real examples.
    fake: Tensor of generated examples.
    gp_mode: One of '1-gp', '0-gp', 'lp', or 'none'. With 'none' no sampling
      or differentiation happens and a zero constant is returned.
    sample_mode: One of 'line', 'real', 'fake', or 'dragan'; selects where
      the gradient is evaluated. Not consulted when `gp_mode` is 'none'.

  Returns:
    A scalar tensor holding the penalty (zero of `real.dtype` for 'none').

  Raises:
    KeyError: If `gp_mode` (other than 'none') or `sample_mode` is not one of
      the supported names.
  """
  if gp_mode == 'none':
    return tf.constant(0, dtype=real.dtype)

  samplers = {
      'line': _sample_line,
      'real': lambda r, _: r,
      'fake': lambda _, g: g,
      'dragan': _sample_DRAGAN,
  }
  penalties = {
      '1-gp': _one_mean_gp,
      '0-gp': _zero_mean_gp,
      'lp': _lipschitz_penalty,
  }

  x = samplers[sample_mode](real, fake)
  # NOTE(review): `tf.gradients` needs a graph context (e.g. `tf.function`);
  # confirm callers never invoke this eagerly.
  grad = tf.gradients(f(x), x)[0]
  return penalties[gp_mode](grad)
#################### data format ####################
def to_uint8(images, min_value=0, max_value=255, dtype=np.uint8):
  """Rescale images from [-1, 1] to [min_value, max_value] and cast them.

  Inputs are clipped to [-1, 1] before scaling so that out-of-range values
  saturate at the output bounds instead of overflowing the integer cast and
  wrapping around. Any array-like (ndarray, nested list, scalar) is accepted.

  Args:
    images: Array-like of values nominally in [-1, 1].
    min_value: Output value that -1 maps to.
    max_value: Output value that +1 maps to.
    dtype: NumPy dtype of the result.

  Returns:
    NumPy array (or NumPy scalar for 0-d input) of type `dtype`.
  """
  clipped = np.clip(np.asarray(images), -1., 1.)
  scaled = (clipped + 1.) / 2. * (max_value - min_value) + min_value
  # The cast truncates towards zero for integer dtypes (0.0 -> 127 with the
  # default range), matching the previous behaviour for in-range inputs.
  return scaled.astype(dtype)
89