imitation_loop.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Imitation loop for PWIL."""

import time

import acme
from acme.utils import counting
from acme.utils import loggers
import dm_env

class TrainEnvironmentLoop(acme.core.Worker):
  """PWIL environment loop.

  This takes `Environment` and `Actor` instances and coordinates their
  interaction. This can be used as:

    loop = TrainEnvironmentLoop(environment, actor, rewarder)
    loop.run(num_steps)

  The `Rewarder` overwrites the timestep from the environment to define
  a custom reward.

  The runner stores episode rewards and a series of statistics in the provided
  `Logger`.
  """

  def __init__(
      self,
      environment,
      actor,
      rewarder,
      counter=None,
      logger=None
  ):
    self._environment = environment
    self._actor = actor
    self._rewarder = rewarder
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger()

  def run(self, num_steps):
    """Perform the run loop.

    Args:
      num_steps: number of steps to run the loop for.
    """
    current_steps = 0
    while current_steps < num_steps:

      # Reset any counts and start the environment.
      start_time = time.time()
      self._rewarder.reset()

      episode_steps = 0
      episode_return = 0
      episode_imitation_return = 0
      timestep = self._environment.reset()

      self._actor.observe_first(timestep)

      # Run an episode.
      while not timestep.last():
        action = self._actor.select_action(timestep.observation)
        obs_act = {'observation': timestep.observation, 'action': action}
        imitation_reward = self._rewarder.compute_reward(obs_act)
        timestep = self._environment.step(action)
        imitation_timestep = dm_env.TimeStep(step_type=timestep.step_type,
                                             reward=imitation_reward,
                                             discount=timestep.discount,
                                             observation=timestep.observation)

        self._actor.observe(action, next_timestep=imitation_timestep)
        self._actor.update()

        # Book-keeping.
        episode_steps += 1
        episode_return += timestep.reward
        episode_imitation_return += imitation_reward

      # Collect the results and combine with counts.
      counts = self._counter.increment(episodes=1, steps=episode_steps)
      steps_per_second = episode_steps / (time.time() - start_time)
      result = {
          'episode_length': episode_steps,
          'episode_return': episode_return,
          'episode_return_imitation': episode_imitation_return,
          'steps_per_second': steps_per_second,
      }
      result.update(counts)

      self._logger.write(result)
      current_steps += episode_steps

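# ----------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the loops in this file
# only rely on a duck-typed rewarder exposing the three methods used here --
# `reset()` and `compute_reward(obs_act)` during training, plus
# `compute_w2_dist_to_expert(trajectory)` during evaluation. The minimal
# stand-in below (its name and constant return values are placeholders) shows
# the expected call signatures; PWIL's actual rewarder derives these
# quantities from expert demonstrations.
# ----------------------------------------------------------------------------
class _ExampleRewarder:
  """Minimal stand-in satisfying the interface used by the loops above."""

  def reset(self):
    # Called once at the start of every episode, before any environment step.
    pass

  def compute_reward(self, obs_act):
    # `obs_act` is a dict with 'observation' and 'action' entries for the
    # current transition; the returned float replaces the environment reward
    # in the timestep passed to the actor.
    del obs_act
    return 0.0

  def compute_w2_dist_to_expert(self, trajectory):
    # `trajectory` is the list of obs_act dicts collected over one evaluation
    # episode; the returned float is logged as the Wasserstein distance.
    del trajectory
    return 0.0
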

class EvalEnvironmentLoop(acme.core.Worker):
  """PWIL evaluation environment loop.

  This takes `Environment` and `Actor` instances and coordinates their
  interaction. This can be used as:

    loop = EvalEnvironmentLoop(environment, actor, rewarder)
    loop.run(num_episodes)

  The `Rewarder` overwrites the timestep from the environment to define
  a custom reward. The evaluation environment loop does not update the agent,
  and computes the Wasserstein distance to the expert demonstrations.

  The runner stores episode rewards and a series of statistics in the provided
  `Logger`.
  """

  def __init__(
      self,
      environment,
      actor,
      rewarder,
      counter=None,
      logger=None
  ):
    self._environment = environment
    self._actor = actor
    self._rewarder = rewarder
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger()

  def run(self, num_episodes):
    """Perform the run loop.

    Args:
      num_episodes: number of episodes to run the loop for.
    """
    for _ in range(num_episodes):
      # Reset any counts and start the environment.
      start_time = time.time()
      self._rewarder.reset()

      episode_steps = 0
      episode_return = 0
      episode_imitation_return = 0
      timestep = self._environment.reset()

      # Run an episode.
      trajectory = []
      while not timestep.last():
        action = self._actor.select_action(timestep.observation)
        obs_act = {'observation': timestep.observation, 'action': action}
        trajectory.append(obs_act)
        imitation_reward = self._rewarder.compute_reward(obs_act)

        timestep = self._environment.step(action)

        # Book-keeping.
        episode_steps += 1
        episode_return += timestep.reward
        episode_imitation_return += imitation_reward

      counts = self._counter.increment(episodes=1, steps=episode_steps)
      w2_dist = self._rewarder.compute_w2_dist_to_expert(trajectory)

      # Collect the results and combine with counts.
      steps_per_second = episode_steps / (time.time() - start_time)
      result = {
          'episode_length': episode_steps,
          'episode_return': episode_return,
          'episode_wasserstein_distance': w2_dist,
          'episode_return_imitation': episode_imitation_return,
          'steps_per_second': steps_per_second,
      }
      result.update(counts)

      self._logger.write(result)
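
# ----------------------------------------------------------------------------
# Example wiring, kept as a comment because it is not part of the original
# module and the construction helpers are hypothetical: `make_environment`,
# `make_actor`, and `make_rewarder` stand in for however the surrounding PWIL
# code builds the dm_env environment, the acme actor, and the rewarder from
# expert demonstrations. Only the loop usage mirrors the class docstrings.
#
#   environment = make_environment()
#   actor = make_actor(environment)
#   rewarder = make_rewarder(expert_demonstrations)
#
#   # Train with imitation rewards for a fixed number of environment steps.
#   train_loop = TrainEnvironmentLoop(environment, actor, rewarder)
#   train_loop.run(num_steps=1_000_000)
#
#   # Evaluate without updating the agent; logs the Wasserstein distance
#   # between the evaluation trajectories and the expert demonstrations.
#   eval_loop = EvalEnvironmentLoop(environment, actor, rewarder)
#   eval_loop.run(num_episodes=10)
# ----------------------------------------------------------------------------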
