google-research

Форк
0
87 строк · 2.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
import numpy as np
17
import tensorflow as tf
18
# pylint: skip-file
19

20
#####################  sample method #####################
21

22
def _sample_line(real, fake):
23
  shape = [tf.shape(real)[0]] + [1] * (real.shape.ndims - 1)
24
  alpha = tf.random.uniform(shape=shape, minval=0, maxval=1)
25
  sample = real + alpha * (fake - real)
26
  sample.set_shape(real.shape)
27
  return sample
28

29

30
def _sample_DRAGAN(real, fake):  # fake is useless
31
  beta = tf.random.uniform(shape=tf.shape(real), minval=0, maxval=1)
32
  fake = real + 0.5 * tf.math.reduce_std(real) * beta
33
  sample = _sample_line(real, fake)
34
  return sample
35

36

37
####################  gradient penalty  ####################
38

39
def _norm(x):
40
  norm = tf.norm(tf.reshape(x, [tf.shape(x)[0], -1]), axis=1)
41
  return norm
42

43

44
def _one_mean_gp(grad):
45
  norm = _norm(grad)
46
  gp = tf.reduce_mean((norm - 1)**2)
47
  return gp
48

49

50
def _zero_mean_gp(grad):
51
  norm = _norm(grad)
52
  gp = tf.reduce_mean(norm**2)
53
  return gp
54

55

56
def _lipschitz_penalty(grad):
57
  norm = _norm(grad)
58
  gp = tf.reduce_mean(tf.maximum(norm - 1, 0)**2)
59
  return gp
60

61

62
def gradient_penalty(f, real, fake, gp_mode, sample_mode):
63
  sample_fns = {
64
    'line': _sample_line,
65
    'real': lambda real, fake: real,
66
    'fake': lambda real, fake: fake,
67
    'dragan': _sample_DRAGAN,
68
  }
69

70
  gp_fns = {
71
    '1-gp': _one_mean_gp,
72
    '0-gp': _zero_mean_gp,
73
    'lp': _lipschitz_penalty,
74
  }
75

76
  if gp_mode == 'none':
77
    gp = tf.constant(0, dtype=real.dtype)
78
  else:
79
    x = sample_fns[sample_mode](real, fake)
80
    grad = tf.gradients(f(x), x)[0]
81
    gp = gp_fns[gp_mode](grad)
82

83
  return gp
84

85
####################  data format  ####################
86
def to_uint8(images, min_value=0, max_value=255, dtype=np.uint8):
87
  return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype)
88

89

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.