google-research
271 lines · 9.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""An implementation of a GAIL discriminator (https://arxiv.org/abs/1606.03476).
17
18In order to make training more stable, this implementation also uses
19gradient penalty from WGAN-GP (https://arxiv.org/abs/1704.00028) or spectral
20normalization (https://openreview.net/forum?id=B1QRgziT-).
21"""
22import tensorflow.compat.v2 as tf
23import tensorflow_gan.python.losses.losses_impl as tfgan_losses
24
25
class SpectralNorm(tf.keras.layers.Wrapper):
  """Spectral Norm wrapper for tf.layers.Dense."""

  def build(self, input_shape):
    """Builds the wrapped Dense layer and the power-iteration state.

    Args:
      input_shape: Shape of the layer input, forwarded to the wrapped layer.
    """
    assert isinstance(self.layer,
                      tf.keras.layers.Dense), 'The class wraps only Dense layer'
    if not self.layer.built:
      self.layer.build(input_shape)

    # Alias the wrapped layer's kernel; `call` divides it by the estimated
    # spectral norm before applying the layer.
    self.kernel = self.layer.kernel

    # Right-singular-vector estimate for power iteration, shaped to match the
    # kernel's output dimension. Randomly initialized; refined on every
    # training call.
    kernel_shape = self.kernel.shape
    self.u = tf.random.truncated_normal(
        shape=[1, kernel_shape[-1]], dtype=tf.float32)

  def call(self, inputs, training=True):
    """Applies the spectrally normalized Dense layer.

    Runs one step of power iteration to refresh the leading singular-value
    estimate, rescales the kernel by it, then applies the wrapped layer.

    Args:
      inputs: A batch of layer inputs.
      training: If True, the power-iteration vector `self.u` is updated.

    Returns:
      The wrapped Dense layer's output computed with the normalized kernel.
    """
    # One power-iteration step: v ∝ u Wᵀ, u' ∝ v W.
    v = tf.nn.l2_normalize(tf.matmul(self.u, self.kernel, transpose_b=True))
    v_w = tf.matmul(v, self.kernel)
    u_next = tf.nn.l2_normalize(v_w)

    # Leading singular value estimate: sigma = (v W) u'ᵀ.
    sigma = tf.squeeze(tf.matmul(v_w, u_next, transpose_b=True))
    self.layer.kernel = self.kernel / sigma

    if training:
      self.u = u_next
    return self.layer(inputs)
55
56
class RatioGANSN(object):
  """An implementation of GAIL discriminator with spectral normalization (https://openreview.net/forum?id=B1QRgziT-)."""

  def __init__(self, state_dim, action_dim, log_interval):
    """Creates an instance of the discriminator.

    Args:
      state_dim: State size.
      action_dim: Action size.
      log_interval: Log losses every N steps.
    """

    def _sn_dense(units, activation=None):
      # Every layer is orthogonally initialized and wrapped in SpectralNorm
      # so the discriminator stays approximately 1-Lipschitz.
      return SpectralNorm(
          tf.keras.layers.Dense(
              units, activation=activation, kernel_initializer='orthogonal'))

    self.discriminator = tf.keras.Sequential([
        _sn_dense(256, activation=tf.nn.tanh),
        _sn_dense(256, activation=tf.nn.tanh),
        _sn_dense(1),
    ])

    # Inputs are concatenated (state, action) vectors.
    self.discriminator.build(input_shape=(None, state_dim + action_dim))

    self.log_interval = log_interval

    self.avg_loss = tf.keras.metrics.Mean('gail loss', dtype=tf.float32)

    self.optimizer = tf.keras.optimizers.Adam()

  @tf.function
  def get_occupancy_ratio(self, states, actions):
    """Returns the occupancy-measure ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of occupancy ratios.
    """
    log_ratio = self.get_log_occupancy_ratio(states, actions)
    return tf.exp(log_ratio)

  @tf.function
  def get_log_occupancy_ratio(self, states, actions):
    """Returns the log occupancy-measure ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of log occupancy ratios (the raw discriminator logits).
    """
    return self.discriminator(
        tf.concat([states, actions], -1), training=False)

  @tf.function
  def update(self, expert_dataset_iter, replay_buffer_iter):
    """Performs a single discriminator training step.

    Args:
      expert_dataset_iter: An tensorflow graph iteratable over expert data.
      replay_buffer_iter: An tensorflow graph iteratable over replay buffer.
    """
    expert_states, expert_actions, _ = next(expert_dataset_iter)
    policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

    policy_batch = tf.concat([policy_states, policy_actions], -1)
    expert_batch = tf.concat([expert_states, expert_actions], -1)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(self.discriminator.variables)
      # Run both batches through the network in one forward pass, then split
      # the logits apart (assumes equal policy and expert batch sizes).
      logits = self.discriminator(tf.concat([policy_batch, expert_batch], 0))
      policy_logits, expert_logits = tf.split(
          logits, num_or_size_splits=2, axis=0)

      # Using the standard value for label smoothing instead of 0.25.
      classification_loss = tfgan_losses.modified_discriminator_loss(
          expert_logits, policy_logits, label_smoothing=0.0)

    grads = tape.gradient(classification_loss, self.discriminator.variables)
    self.optimizer.apply_gradients(zip(grads, self.discriminator.variables))

    self.avg_loss(classification_loss)

    if tf.equal(self.optimizer.iterations % self.log_interval, 0):
      tf.summary.scalar(
          'train gail/loss',
          self.avg_loss.result(),
          step=self.optimizer.iterations)
      self.avg_loss.reset_states()
150
151
class RatioGANGP(object):
  """An implementation of GAIL discriminator with gradient penalty (https://arxiv.org/abs/1704.00028)."""

  def __init__(self, state_dim, action_dim, log_interval,
               grad_penalty_coeff=10):
    """Creates an instance of the discriminator.

    Args:
      state_dim: State size.
      action_dim: Action size.
      log_interval: Log losses every N steps.
      grad_penalty_coeff: A coefficient for gradient penalty.
    """
    # Inputs are concatenated (state, action) vectors.
    self.discriminator = tf.keras.Sequential([
        tf.keras.layers.Dense(
            256, input_shape=(state_dim + action_dim,), activation=tf.nn.tanh),
        tf.keras.layers.Dense(256, activation=tf.nn.tanh),
        tf.keras.layers.Dense(1)
    ])

    self.log_interval = log_interval
    self.grad_penalty_coeff = grad_penalty_coeff

    self.avg_classification_loss = tf.keras.metrics.Mean(
        'classification loss', dtype=tf.float32)
    self.avg_gp_loss = tf.keras.metrics.Mean(
        'gradient penalty', dtype=tf.float32)
    self.avg_total_loss = tf.keras.metrics.Mean(
        'total gan loss', dtype=tf.float32)

    self.optimizer = tf.keras.optimizers.Adam()

  @tf.function
  def get_occupancy_ratio(self, states, actions):
    """Returns the occupancy-measure ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of occupancy ratios.
    """
    inputs = tf.concat([states, actions], -1)
    return tf.exp(self.discriminator(inputs))

  @tf.function
  def get_log_occupancy_ratio(self, states, actions):
    """Returns the log occupancy-measure ratio between two policies.

    Args:
      states: A batch of states.
      actions: A batch of actions.

    Returns:
      A batch of log occupancy ratios (the raw discriminator logits).
    """
    inputs = tf.concat([states, actions], -1)
    return self.discriminator(inputs)

  @tf.function
  def update(self, expert_dataset_iter, replay_buffer_iter):
    """Performs a single discriminator training step.

    Args:
      expert_dataset_iter: An tensorflow graph iteratable over expert data.
      replay_buffer_iter: An tensorflow graph iteratable over replay buffer.
    """
    expert_states, expert_actions, _ = next(expert_dataset_iter)
    policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

    policy_inputs = tf.concat([policy_states, policy_actions], -1)
    expert_inputs = tf.concat([expert_states, expert_actions], -1)

    # Interpolate between policy and expert samples; the gradient penalty is
    # evaluated at these intermediate points (WGAN-GP). Use the dynamic batch
    # size: the static shape (get_shape()[0]) may be None inside a tf.function
    # traced with an unknown batch dimension, which would break
    # tf.random.uniform.
    batch_size = tf.shape(policy_inputs)[0]
    alpha = tf.random.uniform(shape=(batch_size, 1))
    inter = alpha * policy_inputs + (1 - alpha) * expert_inputs

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(self.discriminator.variables)
      policy_output = self.discriminator(policy_inputs)
      expert_output = self.discriminator(expert_inputs)

      # Using the standard value for label smoothing instead of 0.25.
      classification_loss = tfgan_losses.modified_discriminator_loss(
          expert_output, policy_output, label_smoothing=0.0)

      # Inner tape computes d(discriminator)/d(inter) for the penalty term;
      # the outer tape then differentiates the penalty w.r.t. the weights.
      with tf.GradientTape(watch_accessed_variables=False) as tape2:
        tape2.watch(inter)
        output = self.discriminator(inter)

      grad = tape2.gradient(output, [inter])[0]
      # Two-sided penalty pushing the gradient norm towards 1.
      grad_penalty = tf.reduce_mean(tf.pow(tf.norm(grad, axis=-1) - 1, 2))
      total_loss = classification_loss + self.grad_penalty_coeff * grad_penalty

    grads = tape.gradient(total_loss, self.discriminator.variables)

    self.optimizer.apply_gradients(zip(grads, self.discriminator.variables))

    self.avg_classification_loss(classification_loss)
    self.avg_gp_loss(grad_penalty)
    self.avg_total_loss(total_loss)

    if tf.equal(self.optimizer.iterations % self.log_interval, 0):
      tf.summary.scalar(
          'train gail/classification loss',
          self.avg_classification_loss.result(),
          step=self.optimizer.iterations)
      self.avg_classification_loss.reset_states()

      tf.summary.scalar(
          'train gail/gradient penalty',
          self.avg_gp_loss.result(),
          step=self.optimizer.iterations)
      self.avg_gp_loss.reset_states()

      tf.summary.scalar(
          'train gail/loss',
          self.avg_total_loss.result(),
          step=self.optimizer.iterations)
      self.avg_total_loss.reset_states()
272