# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Neural network approximation of density ratio using DualDICE.

Based on the paper `DualDICE: Behavior-Agnostic Estimation of Discounted
Stationary Distribution Corrections' by Ofir Nachum, Yinlam Chow, Bo Dai,
and Lihong Li. See https://arxiv.org/abs/1906.04733
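
Example usage (a sketch; `data`, `target_policy`, `env_state_dim`, and
`num_actions` are placeholders supplied by surrounding dual_dice code, not
names defined in this module):

  parameters = NeuralSolverParameters(
      state_dim=env_state_dim, action_dim=num_actions, gamma=0.99)
  solver = NeuralDualDice(parameters, function_exponent=1.5)
  value_estimate = solver.solve(data, target_policy)
  solver.close()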
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow.compat.v1 as tf
from typing import Callable, Optional, Text

from dual_dice import policy as policy_lib
from dual_dice.algos import base as base_algo


class NeuralSolverParameters(object):
  """Set of parameters common to neural network solvers."""

  def __init__(self,
               state_dim,
               action_dim,
               gamma,
               discrete_actions = True,
               deterministic_env = False,
               hidden_dim = 64,
               hidden_layers = 1,
               activation = tf.nn.tanh,
               nu_learning_rate = 0.0001,
               zeta_learning_rate = 0.001,
               batch_size = 512,
               num_steps = 10000,
               log_every = 500,
               smooth_over = 4,
               summary_writer = None,
               summary_prefix = ''):
    """Initializes the parameters.

    Args:
      state_dim: Dimension of state observations.
      action_dim: If the environment uses continuous actions, this should be
        the dimension of the actions. If the environment uses discrete actions,
        this should be the number of discrete actions.
      gamma: The discount to use.
      discrete_actions: Whether the environment uses discrete actions or not.
      deterministic_env: Whether to take advantage of a deterministic
        environment. If this and average_next_nu are both True, the optimization
        for nu is performed agnostic to zeta (in the primal form).
      hidden_dim: The internal dimension of the neural networks.
      hidden_layers: Number of internal layers in the neural networks.
      activation: Activation to use in the neural networks.
      nu_learning_rate: Learning rate for nu.
      zeta_learning_rate: Learning rate for zeta.
      batch_size: Batch size.
      num_steps: Number of steps (batches) to train for.
      log_every: Log progress and debug information every so many steps.
      smooth_over: Number of iterations to smooth over for final value estimate.
      summary_writer: An optional summary writer to log information to.
      summary_prefix: A prefix to prepend to the summary tags.
    """
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.gamma = gamma
    self.discrete_actions = discrete_actions
    self.deterministic_env = deterministic_env
    self.hidden_dim = hidden_dim
    self.hidden_layers = hidden_layers
    self.activation = activation
    self.nu_learning_rate = nu_learning_rate
    self.zeta_learning_rate = zeta_learning_rate
    self.batch_size = batch_size
    self.num_steps = num_steps
    self.log_every = log_every
    self.smooth_over = smooth_over
    self.summary_writer = summary_writer
    self.summary_prefix = summary_prefix


class NeuralDualDice(base_algo.BaseAlgo):
  """Approximate the density ratio using neural networks."""

  def __init__(self,
               parameters,
               solve_for_state_action_ratio = True,
               average_next_nu = True,
               average_samples = 1,
               function_exponent = 1.5):
    """Initializes the solver.

    Args:
      parameters: An object holding the common neural network parameters.
      solve_for_state_action_ratio: Whether to solve for state-action density
        ratio. Defaults to True, which is recommended, since solving for the
        state density ratio requires importance weights which can introduce
        training instability.
      average_next_nu: Whether to take an empirical expectation over next nu.
        This can improve stability of training.
      average_samples: Number of empirical samples to average over for next nu
        computation (only relevant in continuous environments).
      function_exponent: The form of the function f(x). We use a polynomial
        f(x)=|x|^p / p where p is function_exponent.

    Raises:
      ValueError: If function_exponent is less than or equal to 1.
      NotImplementedError: If actions are continuous.
    """
    self._parameters = parameters
    self._solve_for_state_action_ratio = solve_for_state_action_ratio
    self._average_next_nu = average_next_nu
    self._average_samples = average_samples

    if not self._parameters.discrete_actions:
      raise NotImplementedError('Continuous actions are not fully supported.')

    if function_exponent <= 1:
      raise ValueError('Exponent for f must be greater than 1.')

    # Conjugate of f(x) = |x|^p / p is f*(x) = |x|^q / q where q = p / (p - 1).
    conjugate_exponent = function_exponent / (function_exponent - 1)
    self._f = lambda x: tf.abs(x) ** function_exponent / function_exponent
    self._fstar = lambda x: tf.abs(x) ** conjugate_exponent / conjugate_exponent
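    # For example, with the default function_exponent p = 1.5, the conjugate
    # exponent is q = p / (p - 1) = 3, i.e. f(x) = |x|^1.5 / 1.5 and
    # f*(x) = |x|^3 / 3.  (Here f* is the convex conjugate
    # f*(y) = sup_x [x * y - f(x)], which for f(x) = |x|^p / p equals
    # |y|^q / q with 1/p + 1/q = 1.)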

    # Build and initialize graph.
    self._build_graph()
    self._session = tf.Session()
    self._session.run(tf.global_variables_initializer())

  def _build_graph(self):
    self._create_placeholders()

    # Convert discrete actions to one-hot vectors.
    action = tf.one_hot(self._action, self._parameters.action_dim)
    next_action = tf.one_hot(self._next_action, self._parameters.action_dim)
    initial_action = tf.one_hot(self._initial_action,
                                self._parameters.action_dim)

    nu, next_nu, initial_nu, zeta = self._compute_values(
        action, next_action, initial_action)

    # Density ratio given by approximated zeta values.
    self._density_ratio = zeta

    if self._solve_for_state_action_ratio:
      delta_nu = nu - next_nu * self._parameters.gamma
    else:
      delta_nu = nu - next_nu * self._parameters.gamma * self._policy_ratio

    unweighted_zeta_loss = (delta_nu * zeta - self._fstar(zeta) -
                            (1 - self._parameters.gamma) * initial_nu)
    self._zeta_loss = -(tf.reduce_sum(self._weights * unweighted_zeta_loss) /
                        tf.reduce_sum(self._weights))
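    # Minimizing self._zeta_loss (w.r.t. zeta) maximizes the DualDICE
    # objective from the paper cited above:
    #   J(nu, zeta) = E_dD[(nu(s, a) - gamma * nu(s', a')) * zeta(s, a)
    #                      - f*(zeta(s, a))]
    #                 - (1 - gamma) * E_s0[nu(s0, a0)],
    # where dD is the discounted distribution of the off-policy data and a'
    # and a0 are drawn from the target policy.  nu is trained to minimize J
    # (or a primal form of it below when the environment is deterministic);
    # at the saddle point, zeta recovers the density ratio.  In the
    # state-ratio case, self._policy_ratio importance-weights the next-nu
    # term because actions in the data come from the behavior policy.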
170

171
    if self._parameters.deterministic_env and self._average_next_nu:
172
      # Dont use Fenchel conjugate trick and instead optimize primal.
173
      unweighted_nu_loss = (self._f(delta_nu) -
                            (1 - self._parameters.gamma) * initial_nu)
      self._nu_loss = (tf.reduce_sum(self._weights * unweighted_nu_loss) /
                       tf.reduce_sum(self._weights))
    else:
      self._nu_loss = -self._zeta_loss

    self._train_nu_op = tf.train.AdamOptimizer(
        self._parameters.nu_learning_rate).minimize(
            self._nu_loss, var_list=tf.trainable_variables('nu'))
    self._train_zeta_op = tf.train.AdamOptimizer(
        self._parameters.zeta_learning_rate).minimize(
            self._zeta_loss, var_list=tf.trainable_variables('zeta'))
    self._train_op = tf.group(self._train_nu_op, self._train_zeta_op)

    # Debug quantity (should be close to 1).
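    # With the gamma ** time_step weights fed in solve(), the quantity below
    # is a weighted average of zeta over the data; if zeta is close to the
    # true ratio, this approximates E_dD[d^pi / d^D] = 1.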
    self._debug = (
        tf.reduce_sum(self._weights * self._density_ratio) /
        tf.reduce_sum(self._weights))

  def _create_placeholders(self):
    self._state = tf.placeholder(
        tf.float32, [None, self._parameters.state_dim], 'state')
    self._next_state = tf.placeholder(
        tf.float32, [None, self._parameters.state_dim], 'next_state')
    self._initial_state = tf.placeholder(
        tf.float32, [None, self._parameters.state_dim], 'initial_state')

    self._action = tf.placeholder(tf.int32, [None], 'action')
    self._next_action = tf.placeholder(tf.int32, [None], 'next_action')
    self._initial_action = tf.placeholder(tf.int32, [None], 'initial_action')

    # Ratio of policy sampling probabilities of self._action.
    self._policy_ratio = tf.placeholder(tf.float32, [None], 'policy_ratio')

    # Policy sampling probabilities associated with next state.
    self._target_policy_next_probs = tf.placeholder(
        tf.float32, [None, self._parameters.action_dim])

    self._weights = tf.placeholder(tf.float32, [None], 'weights')

  def _compute_values(self, action, next_action, initial_action):
    nu = self._nu_network(self._state, action)
    initial_nu = self._nu_network(self._initial_state, initial_action)

    if self._average_next_nu:
      # Average next nu over all actions weighted by target policy
      # probabilities.
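      # For discrete actions this computes E_{a' ~ pi(.|s')}[nu(s', a')]
      # exactly rather than relying on a single sampled next action, which
      # can improve training stability.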
      all_next_actions = [
          tf.one_hot(act * tf.ones_like(self._next_action),
                     self._parameters.action_dim)
          for act in range(self._parameters.action_dim)]
      all_next_nu = [self._nu_network(self._next_state, next_action_i)
                     for next_action_i in all_next_actions]
      next_nu = sum(
          self._target_policy_next_probs[:, act_index] * all_next_nu[act_index]
          for act_index in range(self._parameters.action_dim))
    else:
      next_nu = self._nu_network(self._next_state, next_action)

    zeta = self._zeta_network(self._state, action)

    return nu, next_nu, initial_nu, zeta

  def _nu_network(self, state, action):
    with tf.variable_scope('nu', reuse=tf.AUTO_REUSE):
      if self._solve_for_state_action_ratio:
        inputs = tf.concat([state, action], -1)
      else:
        inputs = state
      outputs = self._network(inputs)
    return outputs

  def _zeta_network(self, state, action):
    with tf.variable_scope('zeta', reuse=tf.AUTO_REUSE):
      if self._solve_for_state_action_ratio:
        inputs = tf.concat([state, action], -1)
      else:
        inputs = state
      outputs = self._network(inputs)
    return outputs

  def _network(self, inputs):
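    # Shared MLP builder.  Because it is invoked inside the 'nu' and 'zeta'
    # variable scopes, each network gets its own copy of these variables,
    # which lets the optimizers in _build_graph select them via
    # tf.trainable_variables('nu') and tf.trainable_variables('zeta').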
    with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
      input_dim = int(inputs.shape[-1])
      prev_dim = input_dim
      prev_outputs = inputs
      # Hidden layers.
      for layer in range(self._parameters.hidden_layers):
        with tf.variable_scope('layer%d' % layer, reuse=tf.AUTO_REUSE):
          weight = tf.get_variable(
              'weight', [prev_dim, self._parameters.hidden_dim],
              initializer=tf.glorot_uniform_initializer())
          bias = tf.get_variable(
              'bias', initializer=tf.zeros([self._parameters.hidden_dim]))
          pre_activation = tf.matmul(prev_outputs, weight) + bias
          post_activation = self._parameters.activation(pre_activation)
        prev_dim = self._parameters.hidden_dim
        prev_outputs = post_activation

      # Final layer.
      weight = tf.get_variable(
          'weight_final', [prev_dim, 1],
          initializer=tf.glorot_uniform_initializer())
      bias = tf.get_variable(
          'bias_final', [1], initializer=tf.zeros_initializer())
      output = tf.matmul(prev_outputs, weight) + bias
      return output[Ellipsis, 0]

  def solve(self, data, target_policy, baseline_policy=None):
    """Solves for density ratios and then approximates target policy value.

    Args:
      data: The transition data store to use.
      target_policy: The policy whose value we want to estimate.
      baseline_policy: The policy used to collect the data. If None,
        we default to data.policy.

    Returns:
      Estimated average per-step reward of the target policy.

    Raises:
      ValueError: If NaNs encountered in policy ratio computation.
    """
    if baseline_policy is None:
      baseline_policy = data.policy

    value_estimates = []
    for step in range(self._parameters.num_steps):
      batch = data.sample_batch(self._parameters.batch_size)
      feed_dict = {
          self._state: batch.state,
          self._action: batch.action,
          self._next_state: batch.next_state,
          self._initial_state: batch.initial_state,
          self._weights: self._parameters.gamma ** batch.time_step,
      }
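      # Weighting each sampled transition by gamma ** time_step makes the
      # weighted sums in the losses (and in self._debug) approximate
      # expectations under the discounted occupancy of the data.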
310

311
      # On-policy next action and initial action.
312
      feed_dict[self._next_action] = target_policy.sample_action(
313
          batch.next_state)
314
      feed_dict[self._initial_action] = target_policy.sample_action(
315
          batch.initial_state)
316

317
      if self._average_next_nu:
318
        next_probabilities = target_policy.get_probabilities(batch.next_state)
319
        feed_dict[self._target_policy_next_probs] = next_probabilities
320

321
      policy_ratio = policy_lib.get_policy_ratio(baseline_policy, target_policy,
322
                                                 batch.state, batch.action)
323

324
      if np.any(np.isnan(policy_ratio)):
325
        raise ValueError('NaNs encountered in policy ratio: %s.' % policy_ratio)
326
      feed_dict[self._policy_ratio] = policy_ratio
327

328
      self._session.run(self._train_op, feed_dict=feed_dict)
329

330
      if step % self._parameters.log_every == 0:
331
        debug = self._session.run(self._debug, feed_dict=feed_dict)
332
        tf.logging.info('At step %d' % step)
333
        tf.logging.info('Debug: %s' % debug)
334
        value_estimate = self.estimate_average_reward(data, target_policy)
335
        tf.logging.info('Estimated value: %s' % value_estimate)
336
        value_estimates.append(value_estimate)
337
        tf.logging.info(
338
            'Estimated smoothed value: %s' %
339
            np.mean(value_estimates[-self._parameters.smooth_over:]))
340

341
        if self._parameters.summary_writer:
342
          summary = tf.Summary(value=[
343
              tf.Summary.Value(
344
                  tag='%sdebug' % self._parameters.summary_prefix,
345
                  simple_value=debug),
346
              tf.Summary.Value(
347
                  tag='%svalue_estimate' % self._parameters.summary_prefix,
348
                  simple_value=value_estimate)])
349
          self._parameters.summary_writer.add_summary(summary, step)
350

351
    value_estimate = self.estimate_average_reward(data, target_policy)
352
    tf.logging.info('Estimated value: %s' % value_estimate)
353
    value_estimates.append(value_estimate)
354
    tf.logging.info('Estimated smoothed value: %s' %
355
                    np.mean(value_estimates[-self._parameters.smooth_over:]))
356

357
    # Return estimate that is smoothed over last few iterates.
358
    return np.mean(value_estimates[-self._parameters.smooth_over:])
359

360
  def _state_action_density_ratio(self, state, action):
361
    batched = len(np.shape(state)) > 1
362
    if not batched:
363
      state = np.expand_dims(state, 0)
364
      action = np.expand_dims(action, 0)
365
    density_ratio = self._session.run(
366
        self._density_ratio,
367
        feed_dict={
368
            self._state: state,
369
            self._action: action
370
        })
371
    if not batched:
372
      return density_ratio[0]
373
    return density_ratio
374

375
  def _state_density_ratio(self, state):
376
    batched = len(np.shape(state)) > 1
377
    if not batched:
378
      state = np.expand_dims(state, 0)
379
    density_ratio = self._session.run(
380
        self._density_ratio, feed_dict={self._state: state})
381
    if not batched:
382
      return density_ratio[0]
383
    return density_ratio
384

385
  def estimate_average_reward(self, data, target_policy):
386
    """Estimates value (average per-step reward) of policy.
387

388
    The estimation is based on solved values of zeta, so one should call
389
    solve() before calling this function.
390

391
    Args:
392
      data: The transition data store to use.
393
      target_policy: The policy whose value we want to estimate.
394

395
    Returns:
396
      Estimated average per-step reward of the target policy.
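
    Note: both branches below reweight the rewards stored in `data` by the
    learned density ratios; the exact estimators live in dual_dice.algos.base.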
397
    """
398
    if self._solve_for_state_action_ratio:
399
      return base_algo.estimate_value_from_state_action_ratios(
400
          data, self._parameters.gamma, self._state_action_density_ratio)
401
    else:
402
      return base_algo.estimate_value_from_state_ratios(
403
          data, target_policy, self._parameters.gamma,
404
          self._state_density_ratio)
405

406
  def close(self):
407
    tf.reset_default_graph()
408
    self._session.close()
409
