google-research / conqur_agent.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ConQUR agent to fine-tune the second-last layer of a Q-network."""
import collections

from dopamine.agents.dqn import dqn_agent
import gin
import numpy as np
import tensorflow.compat.v1 as tf
import tf_slim


@gin.configurable
class ConqurAgent(dqn_agent.DQNAgent):
  """DQN agent with last-layer training.

  This is the ConQUR agent that does the heavy lifting of the training
  process and the neural network specification.
  """

  def __init__(self, session, num_actions, random_state):
    """Initializes the agent and constructs the components of its graph.

    Args:
      session: tf.Session, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      random_state: np.random.RandomState, random generator state.
    """
    self.eval_mode = True
    self.random_state = random_state
    super(ConqurAgent, self).__init__(session, num_actions)

  def reload_checkpoint(self, checkpoint_path):
    """Reloads variables from a fully specified checkpoint.

    Args:
      checkpoint_path: string, full path to a checkpoint to reload.
    """
    assert checkpoint_path
    variables_to_restore = tf_slim.get_variables_to_restore()
    reloader = tf.train.Saver(var_list=variables_to_restore)
    reloader.restore(self._sess, checkpoint_path)

    # Cache the last-layer weights and biases of the restored online network.
    var = [
        v for v in variables_to_restore
        if v.name == 'Online/fully_connected_1/weights:0'
    ][0]
    wts = self._sess.run(var)
    var = [
        v for v in variables_to_restore
        if v.name == 'Online/fully_connected_1/biases:0'
    ][0]
    biases = self._sess.run(var)
    num_wts = wts.size + biases.size

    # Do the same for the target network's last layer.
    target_var = [
        v for v in variables_to_restore
        if v.name == 'Target/fully_connected_1/weights:0'
    ][0]
    target_wts = self._sess.run(target_var)
    target_var = [
        v for v in variables_to_restore
        if v.name == 'Target/fully_connected_1/biases:0'
    ][0]
    target_biases = self._sess.run(target_var)
    self.target_wts = target_wts
    self.target_biases = target_biases

    self.last_layer_weights = wts
    self.last_layer_biases = biases
    self.last_layer_wts = np.append(wts, np.expand_dims(biases, axis=0), axis=0)
    self.last_layer_wts = self.last_layer_wts.reshape((num_wts,), order='F')
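    # Note: stacking the (512, num_actions) weight matrix on top of the
    # (1, num_actions) bias row and flattening in column-major ('F') order
    # lays out the 513 parameters of each action's output head contiguously.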

  def _get_network_type(self):
    """Return the type of the outputs of a Q value network.

    Returns:
      net_type: _network_type object defining the outputs of the network.
    """
    return collections.namedtuple('DQN_network', ['q_values'])

  def _network_template(self, state):
    """Builds the convolutional network used to compute the agent's Q-values.

    Args:
      state: tf.Placeholder, contains the agent's current state.

    Returns:
      net: _network_type object containing the tensors output by the network.
    """
    net = tf.cast(state, tf.float32)
    net = tf.math.truediv(net, 255.)
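    # The three convolutional layers below follow the standard Atari DQN torso
    # and are frozen (trainable=False); only the two fully connected layers on
    # top of them are trainable, which is where ConQUR's fine-tuning happens.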
106
    net = tf_slim.conv2d(net, 32, [8, 8], stride=4, trainable=False)
107
    net = tf_slim.conv2d(net, 64, [4, 4], stride=2, trainable=False)
108
    net = tf_slim.conv2d(net, 64, [3, 3], stride=1, trainable=False)
109
    net = tf_slim.flatten(net)
110
    linear_features = tf_slim.fully_connected(net, 512, trainable=True)
111
    q_values = tf_slim.fully_connected(
112
        linear_features, self.num_actions, activation_fn=None)
113
    return self._get_network_type()(q_values), linear_features
114

115
  def _create_network(self, name):
116
    """Builds the convolutional network used to compute the agent's Q-values.
117

118
    Args:
119
      name: str, this name is passed to the tf.keras.Model and used to create
120
        variable scope under the hood by the tf.keras.Model.
121

122
    Returns:
123
      network: tf.keras.Model, the network instantiated by the Keras model.
124
    """
125
    network = self.network(self.num_actions, name=name)
126
    return network
127

128
  def _build_networks(self):
129
    """Builds the Q-value network computations needed for acting and training.
130

131
    These are:
132
      self.online_convnet: For computing the current state's Q-values.
133
      self.target_convnet: For computing the next state's target Q-values.
134
      self._net_outputs: The actual Q-values.
135
      self._q_argmax: The action maximizing the current state's Q-values.
136
      self._replay_net_outputs: The replayed states' Q-values.
137
      self._replay_next_target_net_outputs: The replayed next states' target
138
        Q-values (see Mnih et al., 2015 for details).
139
      self.linear_features: The linear features from second last layer
140
    """
141
    # Calling online_convnet will generate a new graph as defined in
142
    # self._get_network_template using whatever input is passed, but will always
143
    # share the same weights.
144
    self.online_convnet = tf.make_template('Online', self._network_template)
145
    self.target_convnet = tf.make_template('Target', self._network_template)
146
    self._net_outputs, self.linear_features = self.online_convnet(self.state_ph)
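    # Note that, unlike the base DQNAgent, the target network is also wired to
    # state_ph: the labelling code is expected to load the *next* observation
    # into self.state before querying target Q-values (see get_target_q_label
    # and related methods below).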
147
    self._next_target_net_outputs_q, self.target_linear_features = self.target_convnet(
148
        self.state_ph)
149
    self.next_qt_max = tf.reduce_max(self._next_target_net_outputs_q)
150
    self.ddqn_replay_next_target_net_outputs, _ = self.online_convnet(
151
        self._replay.next_states)
152
    self._q_argmax = tf.argmax(self._net_outputs.q_values, axis=1)[0]
153

154
    self._replay_net_outputs, _ = self.online_convnet(self._replay.states)
155
    self._replay_next_target_net_outputs, _ = self.target_convnet(
156
        self._replay.next_states)
157

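  # The public methods below expose a minimal acting and labelling interface
  # (record an observation, select an action, read penultimate-layer features,
  # compute Bellman targets). They are presumably driven directly by the
  # ConQUR training loop rather than by DQNAgent's usual
  # begin_episode/step/end_episode cycle.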
  def update_observation(self, observation, reward, is_terminal):
    """Stores the previous observation and records the new one.

    reward and is_terminal are unused here.
    """
    self._last_observation = self._observation
    self._record_observation(observation)

  def _select_action(self):
    """Selects an action epsilon-greedily.

    self.eps_action is not set in __init__; it is expected to be assigned by
    the surrounding training loop before the agent is stepped.

    Returns:
      int, the selected action.
    """
    if self.random_state.uniform() <= self.eps_action:
      # Choose a random action with probability epsilon. Note that
      # np.random.RandomState.randint excludes the upper bound.
      return self.random_state.randint(0, self.num_actions)
    else:
      # Choose the action with highest Q-value at the current state.
      return self._sess.run(self._q_argmax, {self.state_ph: self.state})

  def step(self):
    """Selects and returns an action for the current state."""
    return self._select_action()

  def observation_to_linear_features(self):
    """Returns the penultimate-layer features for the current state."""
    # These features serve as the 'observation' for last-layer training.
    return self._sess.run(self.linear_features, {self.state_ph: self.state})

  def get_target_q_op(self, reward, is_terminal):
    """Returns ops for the max target Q-value and all target Q-values."""
    return self.next_qt_max, self._next_target_net_outputs_q

  def get_target_q_label(self, reward, is_terminal):
    """Computes the Q-learning target and the raw target-network Q-values."""
    next_qt_max, q_all_actions = self._sess.run(
        self.get_target_q_op(reward, is_terminal), {self.state_ph: self.state})
    is_terminal = is_terminal * 1.0
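    # Bellman backup: reward + gamma^n * max_a' Q_target(s', a'), zeroed out
    # for terminal transitions (cumulative_gamma is the n-step discount
    # inherited from the base DQNAgent).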
    return reward + self.cumulative_gamma * next_qt_max * (
        1. - is_terminal), q_all_actions

  def reset_state(self):
    """Resets the agent's state stack to all zeros."""
    self._reset_state()

  def get_target_q_label_single_target_layer(self, reward, is_terminal, fc):
    """Computes target labels using a single candidate output layer fc."""
    target_linear_feature = self._sess.run(self.target_linear_features,
                                           {self.state_ph: self.state})
    # The state here is the next state.
    q_all_actions = fc(target_linear_feature)
    # Raw Q-values for all actions, before reward and discount are applied.
    q_all_actions_no_reward = q_all_actions
    q_target_max = tf.reduce_max(q_all_actions)
    q_target = reward + self.cumulative_gamma * q_target_max * (1. -
                                                                is_terminal)
    q_all_actions = reward + self.cumulative_gamma * q_all_actions * (
        1. - is_terminal)
    return q_target, q_all_actions, q_all_actions_no_reward

  def get_target_q_label_multiple_target_layers(self, reward, is_terminal,
                                                fc_list, number_actions):
    """Computes target labels for each candidate output layer in fc_list."""
    target_linear_feature = self._sess.run(self.target_linear_features,
                                           {self.state_ph: self.state})
    # The state here is the next state.
    q_all_actions_list = []
    q_all_actions_no_reward_list = []
    q_target_list = []
    for fc in fc_list:
      q_all_actions = fc(target_linear_feature)
      # Raw Q-values for all actions, before reward and discount are applied.
      q_all_actions_no_reward = q_all_actions
      q_target_max = tf.reduce_max(q_all_actions)
      q_target = reward + self.cumulative_gamma * q_target_max * (1. -
                                                                  is_terminal)
      q_all_actions = reward + self.cumulative_gamma * q_all_actions * (
          1. - is_terminal)

      q_all_actions_list.append(q_all_actions)
      q_all_actions_no_reward_list.append(q_all_actions_no_reward)
      q_target_list.append(q_target)
    return q_target_list, q_all_actions_list, q_all_actions_no_reward_list

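# ------------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of constructing the agent and reloading a pre-trained Dopamine DQN
# checkpoint. The checkpoint path, action count and epsilon value are
# placeholders, and the checkpoint is assumed to contain variables under the
# 'Online' and 'Target' scopes used above. The actual ConQUR training harness
# that drives this agent is not shown.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
  tf.disable_v2_behavior()  # Run in TF1 graph mode in case TF2 is installed.
  sess = tf.Session()
  agent = ConqurAgent(
      sess, num_actions=4, random_state=np.random.RandomState(0))
  sess.run(tf.global_variables_initializer())
  # Hypothetical path to a pre-trained Dopamine DQN checkpoint.
  agent.reload_checkpoint('/path/to/pretrained_dqn_checkpoint')
  # eps_action is assigned externally by the training loop (assumption).
  agent.eps_action = 0.001
  print('Selected action:', agent.step())
  print('Penultimate features shape:',
        agent.observation_to_linear_features().shape)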