# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Neural network approximation of density ratio using DualDICE.

Based on the paper `DualDICE: Behavior-Agnostic Estimation of Discounted
Stationary Distribution Corrections' by Ofir Nachum, Yinlam Chow, Bo Dai,
and Lihong Li. See https://arxiv.org/abs/1906.04733
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow.compat.v1 as tf
from typing import Callable, Optional, Text

from dual_dice import policy as policy_lib
from dual_dice.algos import base as base_algo


class NeuralSolverParameters(object):
  """Set of parameters common to neural network solvers."""

  def __init__(self,
               state_dim,
               action_dim,
               gamma,
               discrete_actions = True,
               deterministic_env = False,
               hidden_dim = 64,
               hidden_layers = 1,
               activation = tf.nn.tanh,
               nu_learning_rate = 0.0001,
               zeta_learning_rate = 0.001,
               batch_size = 512,
               num_steps = 10000,
               log_every = 500,
               smooth_over = 4,
               summary_writer = None,
               summary_prefix = ''):
    """Initializes the parameters.

    Args:
      state_dim: Dimension of state observations.
      action_dim: If the environment uses continuous actions, this should be
        the dimension of the actions. If the environment uses discrete actions,
        this should be the number of discrete actions.
      gamma: The discount to use.
      discrete_actions: Whether the environment uses discrete actions or not.
      deterministic_env: Whether to take advantage of a deterministic
        environment. If this and average_next_nu are both True, the
        optimization for nu is performed agnostic to zeta (in the primal form).
      hidden_dim: The internal dimension of the neural networks.
      hidden_layers: Number of internal layers in the neural networks.
      activation: Activation to use in the neural networks.
      nu_learning_rate: Learning rate for nu.
      zeta_learning_rate: Learning rate for zeta.
      batch_size: Batch size.
      num_steps: Number of steps (batches) to train for.
      log_every: Log progress and debug information every so many steps.
      smooth_over: Number of iterations to smooth over for the final value
        estimate.
      summary_writer: An optional summary writer to log information to.
      summary_prefix: A prefix to prepend to the summary tags.
    """
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.gamma = gamma
    self.discrete_actions = discrete_actions
    self.deterministic_env = deterministic_env
    self.hidden_dim = hidden_dim
    self.hidden_layers = hidden_layers
    self.activation = activation
    self.nu_learning_rate = nu_learning_rate
    self.zeta_learning_rate = zeta_learning_rate
    self.batch_size = batch_size
    self.num_steps = num_steps
    self.log_every = log_every
    self.smooth_over = smooth_over
    self.summary_writer = summary_writer
    self.summary_prefix = summary_prefix


class NeuralDualDice(base_algo.BaseAlgo):
  """Approximates the density ratio using neural networks."""

  def __init__(self,
               parameters,
               solve_for_state_action_ratio = True,
               average_next_nu = True,
               average_samples = 1,
               function_exponent = 1.5):
    """Initializes the solver.

    Args:
      parameters: An object holding the common neural network parameters.
      solve_for_state_action_ratio: Whether to solve for the state-action
        density ratio. Defaults to True, which is recommended, since solving
        for the state density ratio requires importance weights which can
        introduce training instability.
      average_next_nu: Whether to take an empirical expectation over next nu.
        This can improve stability of training.
      average_samples: Number of empirical samples to average over for next nu
        computation (only relevant in continuous environments).
      function_exponent: The form of the function f(x). We use a polynomial
        f(x) = |x|^p / p, where p is function_exponent.

    Raises:
      ValueError: If function_exponent is less than or equal to 1.
      NotImplementedError: If actions are continuous.
    """
    self._parameters = parameters
    self._solve_for_state_action_ratio = solve_for_state_action_ratio
    self._average_next_nu = average_next_nu
    self._average_samples = average_samples

    if not self._parameters.discrete_actions:
      raise NotImplementedError('Continuous actions are not fully supported.')

    if function_exponent <= 1:
      raise ValueError('Exponent for f must be greater than 1.')

    # Conjugate of f(x) = |x|^p / p is f*(x) = |x|^q / q, where q = p / (p - 1).
    conjugate_exponent = function_exponent / (function_exponent - 1)
    self._f = lambda x: tf.abs(x) ** function_exponent / function_exponent
    self._fstar = lambda x: tf.abs(x) ** conjugate_exponent / conjugate_exponent

    # Build and initialize graph.
    self._build_graph()
    self._session = tf.Session()
    self._session.run(tf.global_variables_initializer())

  def _build_graph(self):
    self._create_placeholders()

    # Convert discrete actions to one-hot vectors.
    action = tf.one_hot(self._action, self._parameters.action_dim)
    next_action = tf.one_hot(self._next_action, self._parameters.action_dim)
    initial_action = tf.one_hot(self._initial_action,
                                self._parameters.action_dim)

    nu, next_nu, initial_nu, zeta = self._compute_values(
        action, next_action, initial_action)

    # Density ratio given by approximated zeta values.
    self._density_ratio = zeta

    if self._solve_for_state_action_ratio:
      delta_nu = nu - next_nu * self._parameters.gamma
    else:
      delta_nu = nu - next_nu * self._parameters.gamma * self._policy_ratio

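    # DualDICE saddle-point objective: nu minimizes and zeta maximizes
    #   E_D[delta_nu * zeta - f*(zeta)] - (1 - gamma) * E[nu(s0, a0)],
    # where f* is the convex conjugate of f. At the saddle point, zeta
    # recovers the stationary distribution correction, so the zeta loss
    # below is the negative of this (weighted) objective.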
    unweighted_zeta_loss = (delta_nu * zeta - self._fstar(zeta) -
                            (1 - self._parameters.gamma) * initial_nu)
    self._zeta_loss = -(tf.reduce_sum(self._weights * unweighted_zeta_loss) /
                        tf.reduce_sum(self._weights))

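    # When the environment is deterministic and next-nu is averaged over the
    # target policy, the Bellman residual inside f has no sampling noise, so
    # the nu objective can be minimized directly in its primal form.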
    if self._parameters.deterministic_env and self._average_next_nu:
      # Don't use the Fenchel conjugate trick and instead optimize the primal.
      unweighted_nu_loss = (self._f(delta_nu) -
                            (1 - self._parameters.gamma) * initial_nu)
      self._nu_loss = (tf.reduce_sum(self._weights * unweighted_nu_loss) /
                       tf.reduce_sum(self._weights))
    else:
      self._nu_loss = -self._zeta_loss

    self._train_nu_op = tf.train.AdamOptimizer(
        self._parameters.nu_learning_rate).minimize(
            self._nu_loss, var_list=tf.trainable_variables('nu'))
    self._train_zeta_op = tf.train.AdamOptimizer(
        self._parameters.zeta_learning_rate).minimize(
            self._zeta_loss, var_list=tf.trainable_variables('zeta'))
    self._train_op = tf.group(self._train_nu_op, self._train_zeta_op)

    # Debug quantity (should be close to 1, since the true density ratio
    # averages to 1 under the data distribution).
    self._debug = (
        tf.reduce_sum(self._weights * self._density_ratio) /
        tf.reduce_sum(self._weights))

  def _create_placeholders(self):
    self._state = tf.placeholder(
        tf.float32, [None, self._parameters.state_dim], 'state')
    self._next_state = tf.placeholder(
        tf.float32, [None, self._parameters.state_dim], 'next_state')
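    # Initial states/actions are needed for the (1 - gamma) * E[nu(s0, a0)]
    # term of the objective above.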
    self._initial_state = tf.placeholder(
        tf.float32, [None, self._parameters.state_dim], 'initial_state')

    self._action = tf.placeholder(tf.int32, [None], 'action')
    self._next_action = tf.placeholder(tf.int32, [None], 'next_action')
    self._initial_action = tf.placeholder(tf.int32, [None], 'initial_action')

    # Ratio of policy sampling probabilities of self._action.
    self._policy_ratio = tf.placeholder(tf.float32, [None], 'policy_ratio')

    # Policy sampling probabilities associated with next state.
    self._target_policy_next_probs = tf.placeholder(
        tf.float32, [None, self._parameters.action_dim])

    self._weights = tf.placeholder(tf.float32, [None], 'weights')

  def _compute_values(self, action, next_action, initial_action):
    nu = self._nu_network(self._state, action)
    initial_nu = self._nu_network(self._initial_state, initial_action)

    if self._average_next_nu:
      # Average next nu over all actions weighted by target policy
      # probabilities.
      all_next_actions = [
          tf.one_hot(act * tf.ones_like(self._next_action),
                     self._parameters.action_dim)
          for act in range(self._parameters.action_dim)]
      all_next_nu = [self._nu_network(self._next_state, next_action_i)
                     for next_action_i in all_next_actions]
      next_nu = sum(
          self._target_policy_next_probs[:, act_index] * all_next_nu[act_index]
          for act_index in range(self._parameters.action_dim))
    else:
      next_nu = self._nu_network(self._next_state, next_action)

    zeta = self._zeta_network(self._state, action)

    return nu, next_nu, initial_nu, zeta

  def _nu_network(self, state, action):
    with tf.variable_scope('nu', reuse=tf.AUTO_REUSE):
      if self._solve_for_state_action_ratio:
        inputs = tf.concat([state, action], -1)
      else:
        inputs = state
      outputs = self._network(inputs)
      return outputs

  def _zeta_network(self, state, action):
    with tf.variable_scope('zeta', reuse=tf.AUTO_REUSE):
      if self._solve_for_state_action_ratio:
        inputs = tf.concat([state, action], -1)
      else:
        inputs = state
      outputs = self._network(inputs)
      return outputs

  def _network(self, inputs):
    with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
      input_dim = int(inputs.shape[-1])
      prev_dim = input_dim
      prev_outputs = inputs
      # Hidden layers.
      for layer in range(self._parameters.hidden_layers):
        with tf.variable_scope('layer%d' % layer, reuse=tf.AUTO_REUSE):
          weight = tf.get_variable(
              'weight', [prev_dim, self._parameters.hidden_dim],
              initializer=tf.glorot_uniform_initializer())
          bias = tf.get_variable(
              'bias', initializer=tf.zeros([self._parameters.hidden_dim]))
          pre_activation = tf.matmul(prev_outputs, weight) + bias
          post_activation = self._parameters.activation(pre_activation)
        prev_dim = self._parameters.hidden_dim
        prev_outputs = post_activation

      # Final layer.
      weight = tf.get_variable(
          'weight_final', [prev_dim, 1],
          initializer=tf.glorot_uniform_initializer())
      bias = tf.get_variable(
          'bias_final', [1], initializer=tf.zeros_initializer())
      output = tf.matmul(prev_outputs, weight) + bias
      return output[Ellipsis, 0]

  def solve(self, data, target_policy, baseline_policy=None):
    """Solves for density ratios and then approximates target policy value.

    Args:
      data: The transition data store to use.
      target_policy: The policy whose value we want to estimate.
      baseline_policy: The policy used to collect the data. If None,
        we default to data.policy.

    Returns:
      Estimated average per-step reward of the target policy.

    Raises:
      ValueError: If NaNs are encountered in the policy ratio computation.
    """
    if baseline_policy is None:
      baseline_policy = data.policy

    value_estimates = []
    for step in range(self._parameters.num_steps):
      batch = data.sample_batch(self._parameters.batch_size)
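      # Weight each sampled transition by gamma^t (t = its time step), so that
      # weighted batch averages approximate discounted, rather than uniform,
      # expectations over the data.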
      feed_dict = {
          self._state: batch.state,
          self._action: batch.action,
          self._next_state: batch.next_state,
          self._initial_state: batch.initial_state,
          self._weights: self._parameters.gamma ** batch.time_step,
      }

      # On-policy next action and initial action.
      feed_dict[self._next_action] = target_policy.sample_action(
          batch.next_state)
      feed_dict[self._initial_action] = target_policy.sample_action(
          batch.initial_state)

      if self._average_next_nu:
        next_probabilities = target_policy.get_probabilities(batch.next_state)
        feed_dict[self._target_policy_next_probs] = next_probabilities

      policy_ratio = policy_lib.get_policy_ratio(baseline_policy, target_policy,
                                                 batch.state, batch.action)

      if np.any(np.isnan(policy_ratio)):
        raise ValueError('NaNs encountered in policy ratio: %s.' % policy_ratio)
      feed_dict[self._policy_ratio] = policy_ratio

      self._session.run(self._train_op, feed_dict=feed_dict)

      if step % self._parameters.log_every == 0:
        debug = self._session.run(self._debug, feed_dict=feed_dict)
        tf.logging.info('At step %d' % step)
        tf.logging.info('Debug: %s' % debug)
        value_estimate = self.estimate_average_reward(data, target_policy)
        tf.logging.info('Estimated value: %s' % value_estimate)
        value_estimates.append(value_estimate)
        tf.logging.info(
            'Estimated smoothed value: %s' %
            np.mean(value_estimates[-self._parameters.smooth_over:]))

        if self._parameters.summary_writer:
          summary = tf.Summary(value=[
              tf.Summary.Value(
                  tag='%sdebug' % self._parameters.summary_prefix,
                  simple_value=debug),
              tf.Summary.Value(
                  tag='%svalue_estimate' % self._parameters.summary_prefix,
                  simple_value=value_estimate)])
          self._parameters.summary_writer.add_summary(summary, step)

    value_estimate = self.estimate_average_reward(data, target_policy)
    tf.logging.info('Estimated value: %s' % value_estimate)
    value_estimates.append(value_estimate)
    tf.logging.info('Estimated smoothed value: %s' %
                    np.mean(value_estimates[-self._parameters.smooth_over:]))

    # Return estimate that is smoothed over last few iterates.
    return np.mean(value_estimates[-self._parameters.smooth_over:])

  def _state_action_density_ratio(self, state, action):
    batched = len(np.shape(state)) > 1
    if not batched:
      state = np.expand_dims(state, 0)
      action = np.expand_dims(action, 0)
    density_ratio = self._session.run(
        self._density_ratio,
        feed_dict={
            self._state: state,
            self._action: action
        })
    if not batched:
      return density_ratio[0]
    return density_ratio

  def _state_density_ratio(self, state):
    batched = len(np.shape(state)) > 1
    if not batched:
      state = np.expand_dims(state, 0)
    density_ratio = self._session.run(
        self._density_ratio, feed_dict={self._state: state})
    if not batched:
      return density_ratio[0]
    return density_ratio

  def estimate_average_reward(self, data, target_policy):
    """Estimates value (average per-step reward) of policy.

    The estimation is based on solved values of zeta, so one should call
    solve() before calling this function.

    Args:
      data: The transition data store to use.
      target_policy: The policy whose value we want to estimate.

    Returns:
      Estimated average per-step reward of the target policy.
    """
    if self._solve_for_state_action_ratio:
      return base_algo.estimate_value_from_state_action_ratios(
          data, self._parameters.gamma, self._state_action_density_ratio)
    else:
      return base_algo.estimate_value_from_state_ratios(
          data, target_policy, self._parameters.gamma,
          self._state_density_ratio)

  def close(self):
    tf.reset_default_graph()
    self._session.close()

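# Example usage (a minimal sketch; `transition_data`, `target_policy`, and
# `behavior_policy` are assumed to be built elsewhere with the interfaces this
# module relies on, e.g. `transition_data.sample_batch`, `.policy`, and policy
# objects compatible with `policy_lib`):
#
#   parameters = NeuralSolverParameters(
#       state_dim=4, action_dim=2, gamma=0.99)
#   solver = NeuralDualDice(parameters, solve_for_state_action_ratio=True)
#   value_estimate = solver.solve(
#       transition_data, target_policy, baseline_policy=behavior_policy)
#   solver.close()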