# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""ConQUR agent to fine-tune the second-last layer of a Q-network."""
17import collections
18
19from dopamine.agents.dqn import dqn_agent
20import gin
21import numpy as np
22import tensorflow.compat.v1 as tf
23import tf_slim
24
25
@gin.configurable
class ConqurAgent(dqn_agent.DQNAgent):
  """DQN agent with last-layer training.

  This is the ConQUR agent that does the heavy lifting of the training
  process and the neural-network specification.
  """

  def __init__(self, session, num_actions, random_state):
    """Initializes the agent and constructs the components of its graph.

    Args:
      session: tf.Session, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      random_state: np.random.RandomState, random generator state.
    """
    self.eval_mode = True
    self.random_state = random_state
    super(ConqurAgent, self).__init__(session, num_actions)

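  # A minimal construction sketch (hypothetical values; assumes a TF1 session
  # and the usual Dopamine/gin setup):
  #
  #   sess = tf.Session()
  #   agent = ConqurAgent(sess, num_actions=4,
  #                       random_state=np.random.RandomState(0))
  #   sess.run(tf.global_variables_initializer())
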
  def reload_checkpoint(self, checkpoint_path):
    """Reloads variables from a fully specified checkpoint.

    Args:
      checkpoint_path: string, full path to a checkpoint to reload.
    """
    assert checkpoint_path
    variables_to_restore = tf_slim.get_variables_to_restore()
    reloader = tf.train.Saver(var_list=variables_to_restore)
    reloader.restore(self._sess, checkpoint_path)

    # Pull the last (fully_connected_1) layer's weights and biases out of the
    # restored Online network.
    var = [
        v for v in variables_to_restore
        if v.name == 'Online/fully_connected_1/weights:0'
    ][0]
    wts = self._sess.run(var)
    var = [
        v for v in variables_to_restore
        if v.name == 'Online/fully_connected_1/biases:0'
    ][0]
    biases = self._sess.run(var)
    num_wts = wts.size + biases.size

    # Do the same for the Target network's last layer.
    target_var = [
        v for v in variables_to_restore
        if v.name == 'Target/fully_connected_1/weights:0'
    ][0]
    target_wts = self._sess.run(target_var)
    target_var = [
        v for v in variables_to_restore
        if v.name == 'Target/fully_connected_1/biases:0'
    ][0]
    target_biases = self._sess.run(target_var)
    self.target_wts = target_wts
    self.target_biases = target_biases

    # Stack the bias row under the weight matrix and flatten column-major.
    self.last_layer_weights = wts
    self.last_layer_biases = biases
    self.last_layer_wts = np.append(wts, np.expand_dims(biases, axis=0), axis=0)
    self.last_layer_wts = self.last_layer_wts.reshape((num_wts,), order='F')

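  # Layout sketch for `last_layer_wts` (derived from the reshape above): the
  # (512, num_actions) weight matrix and the bias row are stacked into a
  # (513, num_actions) array and flattened column-major (order='F'), so the
  # vector holds one 513-entry (weights, bias) column per action, e.g.
  #
  #   theta = agent.last_layer_wts
  #   assert theta.shape == ((512 + 1) * agent.num_actions,)
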
  def _get_network_type(self):
    """Return the type of the outputs of a Q value network.

    Returns:
      net_type: _network_type object defining the outputs of the network.
    """
    return collections.namedtuple('DQN_network', ['q_values'])

  def _network_template(self, state):
    """Builds the convolutional network used to compute the agent's Q-values.

    Args:
      state: tf.Placeholder, contains the agent's current state.

    Returns:
      net: _network_type object containing the tensors output by the network.
      linear_features: activations of the penultimate (512-unit) layer.
    """
    net = tf.cast(state, tf.float32)
    net = tf.math.truediv(net, 255.)
    net = tf_slim.conv2d(net, 32, [8, 8], stride=4, trainable=False)
    net = tf_slim.conv2d(net, 64, [4, 4], stride=2, trainable=False)
    net = tf_slim.conv2d(net, 64, [3, 3], stride=1, trainable=False)
    net = tf_slim.flatten(net)
    linear_features = tf_slim.fully_connected(net, 512, trainable=True)
    q_values = tf_slim.fully_connected(
        linear_features, self.num_actions, activation_fn=None)
    return self._get_network_type()(q_values), linear_features

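  # Note: only the two fully connected layers above are trainable; the three
  # convolutional layers are frozen (trainable=False) and simply reuse the
  # restored checkpoint weights, so training only adjusts the last layer(s).
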
  def _create_network(self, name):
    """Builds the convolutional network used to compute the agent's Q-values.

    Args:
      name: str, this name is passed to tf.keras.Model and is used to create
        the variable scope under the hood.

    Returns:
      network: tf.keras.Model, the network instantiated by the Keras model.
    """
    network = self.network(self.num_actions, name=name)
    return network

  def _build_networks(self):
    """Builds the Q-value network computations needed for acting and training.

    These are:
      self.online_convnet: For computing the current state's Q-values.
      self.target_convnet: For computing the next state's target Q-values.
      self._net_outputs: The actual Q-values.
      self._q_argmax: The action maximizing the current state's Q-values.
      self._replay_net_outputs: The replayed states' Q-values.
      self._replay_next_target_net_outputs: The replayed next states' target
        Q-values (see Mnih et al., 2015 for details).
      self.linear_features: The linear features from the second-to-last layer.
    """
    # Calling online_convnet will generate a new graph as defined in
    # self._network_template using whatever input is passed, but will always
    # share the same weights.
    self.online_convnet = tf.make_template('Online', self._network_template)
    self.target_convnet = tf.make_template('Target', self._network_template)
    self._net_outputs, self.linear_features = self.online_convnet(self.state_ph)
    (self._next_target_net_outputs_q,
     self.target_linear_features) = self.target_convnet(self.state_ph)
    self.next_qt_max = tf.reduce_max(self._next_target_net_outputs_q)
    self.ddqn_replay_next_target_net_outputs, _ = self.online_convnet(
        self._replay.next_states)
    self._q_argmax = tf.argmax(self._net_outputs.q_values, axis=1)[0]

    self._replay_net_outputs, _ = self.online_convnet(self._replay.states)
    self._replay_next_target_net_outputs, _ = self.target_convnet(
        self._replay.next_states)

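  # Fetch sketch (assumes the graph has been built and `agent._sess` is a live
  # session): the greedy action and the penultimate-layer features for the
  # current frame stack can be pulled in one run call, e.g.
  #
  #   action, phi = agent._sess.run(
  #       [agent._q_argmax, agent.linear_features],
  #       {agent.state_ph: agent.state})
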
  def update_observation(self, observation, reward, is_terminal):
    """Records the latest observation (reward and is_terminal are unused)."""
    self._last_observation = self._observation
    self._record_observation(observation)

  def _select_action(self):
    """Selects an action epsilon-greedily w.r.t. the current Q-values."""
    if self.random_state.uniform() <= self.eps_action:
      # Choose a random action with probability epsilon. Note that
      # np.random.RandomState.randint excludes its upper bound.
      return self.random_state.randint(0, self.num_actions)
    else:
      # Choose the action with highest Q-value at the current state.
      return self._sess.run(self._q_argmax, {self.state_ph: self.state})

  def step(self):
    """Returns the action selected for the current state."""
    return self._select_action()

  def observation_to_linear_features(self):
    """Returns the penultimate-layer features for the current state."""
    # This is essentially the observation for the DNN.
    return self._sess.run(self.linear_features, {self.state_ph: self.state})

  def get_target_q_op(self, reward, is_terminal):
    """Returns ops for the max target Q-value and the target network outputs."""
    return self.next_qt_max, self._next_target_net_outputs_q

  def get_target_q_label(self, reward, is_terminal):
    """Computes the Bellman target, treating self.state as the next state."""
    next_qt_max, q_all_actions = self._sess.run(
        self.get_target_q_op(reward, is_terminal), {self.state_ph: self.state})
    is_terminal = is_terminal * 1.0
    return reward + self.cumulative_gamma * next_qt_max * (
        1. - is_terminal), q_all_actions

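  # The label above is the standard Bellman backup against the Target network,
  #   reward + cumulative_gamma * max_a Q_target(next_state, a),
  # zeroed past terminal transitions. A hedged usage sketch, assuming
  # self.state already holds the *next* state's frame stack:
  #
  #   target, q_next_all = agent.get_target_q_label(reward=1.0,
  #                                                 is_terminal=False)
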
  def reset_state(self):
    self._reset_state()

  def get_target_q_label_single_target_layer(self, reward, is_terminal, fc):
    """Computes Bellman targets with `fc` as the target net's last layer."""
    target_linear_feature = self._sess.run(self.target_linear_features,
                                           {self.state_ph: self.state})
    # The state here is the next state.
    q_all_actions = fc(target_linear_feature)
    # Raw Q-values, before the reward and discount are applied.
    q_all_actions_no_ward = q_all_actions
    q_target_max = tf.reduce_max(q_all_actions)
    q_target = reward + self.cumulative_gamma * q_target_max * (1. -
                                                                is_terminal)
    q_all_actions = reward + self.cumulative_gamma * q_all_actions * (
        1. - is_terminal)
    return q_target, q_all_actions, q_all_actions_no_ward

  def get_target_q_label_multiple_target_layers(self, reward, is_terminal,
                                                fc_list, number_actions):
    """Computes Bellman targets for each candidate last layer in `fc_list`."""
    target_linear_feature = self._sess.run(self.target_linear_features,
                                           {self.state_ph: self.state})
    # The state here is the next state.
    q_all_actions_list = []
    q_all_actions_no_ward_list = []
    q_target_list = []
    for fc in fc_list:
      q_all_actions = fc(target_linear_feature)
      # Raw Q-values, before the reward and discount are applied.
      q_all_actions_no_ward = q_all_actions
      q_target_max = tf.reduce_max(q_all_actions)
      q_target = reward + self.cumulative_gamma * q_target_max * (1. -
                                                                  is_terminal)
      q_all_actions = reward + self.cumulative_gamma * q_all_actions * (
          1. - is_terminal)

      q_all_actions_list.append(q_all_actions)
      q_all_actions_no_ward_list.append(q_all_actions_no_ward)
      q_target_list.append(q_target)
    return q_target_list, q_all_actions_list, q_all_actions_no_ward_list
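
  # Usage sketch for the multi-head variant (hypothetical `heads`: one Keras
  # Dense layer per candidate last layer, each mapping the 512-d target
  # features to num_actions Q-values):
  #
  #   heads = [tf.keras.layers.Dense(agent.num_actions) for _ in range(5)]
  #   q_targets, q_alls, q_raws = (
  #       agent.get_target_q_label_multiple_target_layers(
  #           reward, is_terminal, heads, agent.num_actions))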