# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An implementation of a GAIL discriminator (https://arxiv.org/abs/1606.03476).

In order to make training more stable, this implementation also uses
gradient penalty from WGAN-GP (https://arxiv.org/abs/1704.00028) or spectral
normalization (https://openreview.net/forum?id=B1QRgziT-).
"""
import tensorflow.compat.v2 as tf
import tensorflow_gan.python.losses.losses_impl as tfgan_losses


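# Note added for clarity (not in the original file): both classes below train
# a binary classifier with the standard (modified) GAN loss, where expert
# (state, action) pairs are treated as "real" and policy pairs as "fake". At
# the optimum of that loss the raw discriminator logit approximates
# log(d_expert(s, a) / d_policy(s, a)), which is why the untransformed output
# is exposed as a log occupancy ratio and its exponential as the ratio itself.

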
class SpectralNorm(tf.keras.layers.Wrapper):
  """Spectral Norm wrapper for tf.keras.layers.Dense."""

  def build(self, input_shape):
    assert isinstance(self.layer,
                      tf.keras.layers.Dense), 'The class wraps only Dense layer'
    if not self.layer.built:
      self.layer.build(input_shape)

      self.kernel = self.layer.kernel

      shape = self.kernel.shape

      # Persistent vector used for power iteration over the kernel.
      self.u = tf.random.truncated_normal(
          shape=[1, shape[-1]], dtype=tf.float32)

  def call(self, inputs, training=True):
    # One step of power iteration estimates sigma, the largest singular value
    # (spectral norm) of the kernel.
    u = self.u
    u_wt = tf.matmul(u, self.kernel, transpose_b=True)
    u_wt_norm = tf.nn.l2_normalize(u_wt)
    u_wt_w_norm = tf.nn.l2_normalize(tf.matmul(u_wt_norm, self.kernel))
    sigma = tf.squeeze(
        tf.matmul(
            tf.matmul(u_wt_norm, self.kernel), u_wt_w_norm, transpose_b=True))
    # Rescale the wrapped layer's kernel by the estimated spectral norm.
    self.layer.kernel = self.kernel / sigma

    if training:
      # Keep the refined power-iteration vector for the next call.
      self.u = u_wt_w_norm
    return self.layer(inputs)


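# A minimal usage sketch (illustration only, not part of the original file):
# wrapping a Dense layer so that its kernel is rescaled by an estimate of its
# spectral norm on every forward pass. The layer width, input size and batch
# size below are arbitrary.
#
#   layer = SpectralNorm(
#       tf.keras.layers.Dense(4, kernel_initializer='orthogonal'))
#   outputs = layer(tf.random.normal([32, 8]))  # shape [32, 4]

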
class RatioGANSN(object):
  """An implementation of GAIL discriminator with spectral normalization (https://openreview.net/forum?id=B1QRgziT-)."""

  def __init__(self, state_dim, action_dim, log_interval):
    """Creates an instance of the discriminator.

    Args:
      state_dim: State size.
      action_dim: Action size.
      log_interval: Log losses every N steps.
    """
    dense = tf.keras.layers.Dense
    self.discriminator = tf.keras.Sequential([
        SpectralNorm(
            dense(256, activation=tf.nn.tanh, kernel_initializer='orthogonal')),
        SpectralNorm(
            dense(256, activation=tf.nn.tanh, kernel_initializer='orthogonal')),
        SpectralNorm(dense(1, kernel_initializer='orthogonal'))
    ])

    self.discriminator.build(input_shape=(None, state_dim + action_dim))

    self.log_interval = log_interval

    self.avg_loss = tf.keras.metrics.Mean('gail loss', dtype=tf.float32)

    self.optimizer = tf.keras.optimizers.Adam()

  @tf.function
  def get_occupancy_ratio(self, states, actions):
    """Returns occupancy ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of occupancy ratios.
    """
    return tf.exp(self.get_log_occupancy_ratio(states, actions))

  @tf.function
  def get_log_occupancy_ratio(self, states, actions):
    """Returns log occupancy ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of log occupancy ratios.
    """
    inputs = tf.concat([states, actions], -1)
    return self.discriminator(inputs, training=False)

  @tf.function
  def update(self, expert_dataset_iter, replay_buffer_iter):
    """Performs a single training step of the discriminator.

    Args:
      expert_dataset_iter: A TensorFlow graph iterable over expert data.
      replay_buffer_iter: A TensorFlow graph iterable over the replay buffer.
    """
    expert_states, expert_actions, _ = next(expert_dataset_iter)
    policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

    policy_inputs = tf.concat([policy_states, policy_actions], -1)
    expert_inputs = tf.concat([expert_states, expert_actions], -1)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(self.discriminator.variables)
      inputs = tf.concat([policy_inputs, expert_inputs], 0)
      outputs = self.discriminator(inputs)

      policy_output, expert_output = tf.split(
          outputs, num_or_size_splits=2, axis=0)

      # Using the standard value for label smoothing instead of 0.25.
      classification_loss = tfgan_losses.modified_discriminator_loss(
          expert_output, policy_output, label_smoothing=0.0)

    grads = tape.gradient(classification_loss, self.discriminator.variables)

    self.optimizer.apply_gradients(zip(grads, self.discriminator.variables))

    self.avg_loss(classification_loss)

    if tf.equal(self.optimizer.iterations % self.log_interval, 0):
      tf.summary.scalar(
          'train gail/loss',
          self.avg_loss.result(),
          step=self.optimizer.iterations)
      self.avg_loss.reset_states()


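# A minimal usage sketch (illustration only, not part of the original file).
# `update` expects `expert_dataset_iter` to yield (states, actions, ...) and
# `replay_buffer_iter` to yield a tuple whose first element unpacks into
# (states, actions, _, _, _); both iterators come from whatever data pipeline
# the training script builds. The dimensions and step count are arbitrary.
#
#   gail = RatioGANSN(state_dim=11, action_dim=3, log_interval=1000)
#   for _ in range(num_discriminator_steps):
#     gail.update(expert_dataset_iter, replay_buffer_iter)
#   log_ratios = gail.get_log_occupancy_ratio(states, actions)

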
class RatioGANGP(object):
  """An implementation of GAIL discriminator with gradient penalty (https://arxiv.org/abs/1704.00028)."""

  def __init__(self, state_dim, action_dim, log_interval,
               grad_penalty_coeff=10):
    """Creates an instance of the discriminator.

    Args:
      state_dim: State size.
      action_dim: Action size.
      log_interval: Log losses every N steps.
      grad_penalty_coeff: A coefficient for the gradient penalty.
    """
    self.discriminator = tf.keras.Sequential([
        tf.keras.layers.Dense(
            256, input_shape=(state_dim + action_dim,), activation=tf.nn.tanh),
        tf.keras.layers.Dense(256, activation=tf.nn.tanh),
        tf.keras.layers.Dense(1)
    ])

    self.log_interval = log_interval
    self.grad_penalty_coeff = grad_penalty_coeff

    self.avg_classification_loss = tf.keras.metrics.Mean(
        'classification loss', dtype=tf.float32)
    self.avg_gp_loss = tf.keras.metrics.Mean(
        'gradient penalty', dtype=tf.float32)
    self.avg_total_loss = tf.keras.metrics.Mean(
        'total gan loss', dtype=tf.float32)

    self.optimizer = tf.keras.optimizers.Adam()

  @tf.function
  def get_occupancy_ratio(self, states, actions):
    """Returns occupancy ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of occupancy ratios.
    """
    inputs = tf.concat([states, actions], -1)
    return tf.exp(self.discriminator(inputs))

  @tf.function
  def get_log_occupancy_ratio(self, states, actions):
    """Returns log occupancy ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of log occupancy ratios.
    """
    inputs = tf.concat([states, actions], -1)
    return self.discriminator(inputs)

  @tf.function
  def update(self, expert_dataset_iter, replay_buffer_iter):
    """Performs a single training step of the discriminator.

    Args:
      expert_dataset_iter: A TensorFlow graph iterable over expert data.
      replay_buffer_iter: A TensorFlow graph iterable over the replay buffer.
    """
    expert_states, expert_actions, _ = next(expert_dataset_iter)
    policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

    policy_inputs = tf.concat([policy_states, policy_actions], -1)
    expert_inputs = tf.concat([expert_states, expert_actions], -1)

    # Random interpolates between policy and expert samples; the gradient
    # penalty is evaluated at these points.
    alpha = tf.random.uniform(shape=(policy_inputs.get_shape()[0], 1))
    inter = alpha * policy_inputs + (1 - alpha) * expert_inputs

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(self.discriminator.variables)
      policy_output = self.discriminator(policy_inputs)
      expert_output = self.discriminator(expert_inputs)

      # Using the standard value for label smoothing instead of 0.25.
      classification_loss = tfgan_losses.modified_discriminator_loss(
          expert_output, policy_output, label_smoothing=0.0)

      with tf.GradientTape(watch_accessed_variables=False) as tape2:
        tape2.watch(inter)
        output = self.discriminator(inter)

      # WGAN-GP penalty: push the gradient norm at the interpolates towards 1.
      grad = tape2.gradient(output, [inter])[0]
      grad_penalty = tf.reduce_mean(tf.pow(tf.norm(grad, axis=-1) - 1, 2))
      total_loss = classification_loss + self.grad_penalty_coeff * grad_penalty

    grads = tape.gradient(total_loss, self.discriminator.variables)

    self.optimizer.apply_gradients(zip(grads, self.discriminator.variables))

    self.avg_classification_loss(classification_loss)
    self.avg_gp_loss(grad_penalty)
    self.avg_total_loss(total_loss)

    if tf.equal(self.optimizer.iterations % self.log_interval, 0):
      tf.summary.scalar(
          'train gail/classification loss',
          self.avg_classification_loss.result(),
          step=self.optimizer.iterations)
      self.avg_classification_loss.reset_states()

      tf.summary.scalar(
          'train gail/gradient penalty',
          self.avg_gp_loss.result(),
          step=self.optimizer.iterations)
      self.avg_gp_loss.reset_states()

      tf.summary.scalar(
          'train gail/loss',
          self.avg_total_loss.result(),
          step=self.optimizer.iterations)
      self.avg_total_loss.reset_states()
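
# A minimal usage sketch (illustration only, not part of the original file).
# RatioGANGP is constructed and updated the same way as RatioGANSN; the
# differences are the plain (unnormalized) Dense layers and the gradient
# penalty added to the classification loss inside `update`.
#
#   gail_gp = RatioGANGP(state_dim=11, action_dim=3, log_interval=1000,
#                        grad_penalty_coeff=10)
#   gail_gp.update(expert_dataset_iter, replay_buffer_iter)
#   ratios = gail_gp.get_occupancy_ratio(states, actions)  # exp(D([s, a]))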