# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Imitation loop for PWIL."""
17
18import time
19
20import acme
21from acme.utils import counting
22from acme.utils import loggers
23import dm_env
24
25
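# Both loops below rely on duck typing: `rewarder` must expose `reset()`,
# `compute_reward(obs_act)` and, for evaluation, `compute_w2_dist_to_expert(
# trajectory)`; `actor` must expose `select_action`, `observe_first`,
# `observe` and `update`.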
class TrainEnvironmentLoop(acme.core.Worker):
  """PWIL environment loop.

  This takes `Environment` and `Actor` instances and coordinates their
  interaction. This can be used as:

    loop = TrainEnvironmentLoop(environment, actor, rewarder)
    loop.run(num_steps)

  The `Rewarder` overwrites the reward of each environment timestep to define
  a custom (imitation) reward: the actor trains on the imitation reward, while
  the environment reward is only logged.

  The runner stores episode rewards and a series of statistics in the
  provided `Logger`.
  """

  def __init__(
      self,
      environment,
      actor,
      rewarder,
      counter=None,
      logger=None
  ):
    """Initializes the loop.

    Args:
      environment: `dm_env.Environment` to generate trajectories from.
      actor: actor used to interact with the environment.
      rewarder: computes the imitation reward for each transition.
      counter: optional counter for bookkeeping; created if None.
      logger: optional logger for results; created if None.
    """
    self._environment = environment
    self._actor = actor
    self._rewarder = rewarder
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger()

  def run(self, num_steps):
    """Performs the run loop.

    Args:
      num_steps: minimum number of environment steps to run the loop for.
        Episodes always run to completion, so the loop may take slightly
        more steps than requested.
    """
    current_steps = 0
    while current_steps < num_steps:

      # Reset any counts and start the environment.
      start_time = time.time()
      self._rewarder.reset()

      episode_steps = 0
      episode_return = 0
      episode_imitation_return = 0
      timestep = self._environment.reset()

      self._actor.observe_first(timestep)

      # Run an episode.
      while not timestep.last():
        action = self._actor.select_action(timestep.observation)
        obs_act = {'observation': timestep.observation, 'action': action}
        imitation_reward = self._rewarder.compute_reward(obs_act)
        timestep = self._environment.step(action)
        # Substitute the imitation reward for the environment reward in the
        # timestep the actor observes.
        imitation_timestep = dm_env.TimeStep(
            step_type=timestep.step_type,
            reward=imitation_reward,
            discount=timestep.discount,
            observation=timestep.observation)

        self._actor.observe(action, next_timestep=imitation_timestep)
        self._actor.update()

        # Book-keeping.
        episode_steps += 1
        episode_return += timestep.reward
        episode_imitation_return += imitation_reward

      # Collect the results and combine with counts.
      counts = self._counter.increment(episodes=1, steps=episode_steps)
      steps_per_second = episode_steps / (time.time() - start_time)
      result = {
          'episode_length': episode_steps,
          'episode_return': episode_return,
          'episode_return_imitation': episode_imitation_return,
          'steps_per_second': steps_per_second,
      }
      result.update(counts)

      self._logger.write(result)
      current_steps += episode_steps

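
# A minimal object satisfying the rewarder interface consumed by the loops.
# This is a hypothetical stand-in for illustration and smoke tests only; it
# is not PWIL's actual rewarder, which is built from expert demonstrations.
class _ExampleConstantRewarder:
  """Toy rewarder stand-in; returns a constant imitation reward."""

  def reset(self):
    pass

  def compute_reward(self, obs_act):
    del obs_act  # A real rewarder scores the observation-action pair.
    return 0.0

  def compute_w2_dist_to_expert(self, trajectory):
    del trajectory  # A real rewarder compares against expert transitions.
    return 0.0
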

class EvalEnvironmentLoop(acme.core.Worker):
  """PWIL evaluation environment loop.

  This takes `Environment` and `Actor` instances and coordinates their
  interaction. This can be used as:

    loop = EvalEnvironmentLoop(environment, actor, rewarder)
    loop.run(num_episodes)

  The `Rewarder` overwrites the reward of each environment timestep to define
  a custom (imitation) reward. The evaluation loop does not update the agent,
  and it computes the Wasserstein distance between each episode and the
  expert demonstrations.

  The runner stores episode rewards and a series of statistics in the
  provided `Logger`.
  """

  def __init__(
      self,
      environment,
      actor,
      rewarder,
      counter=None,
      logger=None
  ):
    """Initializes the loop.

    Args:
      environment: `dm_env.Environment` to generate trajectories from.
      actor: actor used to interact with the environment.
      rewarder: computes the imitation reward and the Wasserstein distance.
      counter: optional counter for bookkeeping; created if None.
      logger: optional logger for results; created if None.
    """
    self._environment = environment
    self._actor = actor
    self._rewarder = rewarder
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger()

  def run(self, num_episodes):
    """Performs the run loop.

    Args:
      num_episodes: number of episodes to run the loop for.
    """
    for _ in range(num_episodes):
      # Reset any counts and start the environment.
      start_time = time.time()
      self._rewarder.reset()

      episode_steps = 0
      episode_return = 0
      episode_imitation_return = 0
      timestep = self._environment.reset()

      # Run an episode.
      trajectory = []
      while not timestep.last():
        action = self._actor.select_action(timestep.observation)
        obs_act = {'observation': timestep.observation, 'action': action}
        trajectory.append(obs_act)
        imitation_reward = self._rewarder.compute_reward(obs_act)

        timestep = self._environment.step(action)

        # Book-keeping.
        episode_steps += 1
        episode_return += timestep.reward
        episode_imitation_return += imitation_reward

      counts = self._counter.increment(episodes=1, steps=episode_steps)
      # Score the completed trajectory against the expert demonstrations.
      w2_dist = self._rewarder.compute_w2_dist_to_expert(trajectory)

      # Collect the results and combine with counts.
      steps_per_second = episode_steps / (time.time() - start_time)
      result = {
          'episode_length': episode_steps,
          'episode_return': episode_return,
          'episode_wasserstein_distance': w2_dist,
          'episode_return_imitation': episode_imitation_return,
          'steps_per_second': steps_per_second,
      }
      result.update(counts)

      self._logger.write(result)
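

if __name__ == '__main__':
  # Illustrative smoke test with hypothetical stand-ins: it only exercises
  # TrainEnvironmentLoop's control flow. Real experiments wire in a dm_env
  # environment, an acme agent and a PWIL rewarder instead of these fakes.
  import numpy as np

  class _FakeEnv(dm_env.Environment):
    """Toy environment terminating after a fixed number of steps."""

    def __init__(self, horizon=5):
      self._horizon = horizon
      self._num_steps = 0

    def reset(self):
      self._num_steps = 0
      return dm_env.restart(np.zeros(2, dtype=np.float32))

    def step(self, action):
      del action  # The toy dynamics ignore the action.
      self._num_steps += 1
      observation = np.full(2, self._num_steps, dtype=np.float32)
      if self._num_steps >= self._horizon:
        return dm_env.termination(reward=1.0, observation=observation)
      return dm_env.transition(reward=0.0, observation=observation)

    def observation_spec(self):
      return None  # Not used by the loops above.

    def action_spec(self):
      return None  # Not used by the loops above.

  class _FakeActor:
    """Toy actor exposing the methods the loops call."""

    def select_action(self, observation):
      del observation
      return np.zeros(1, dtype=np.float32)

    def observe_first(self, timestep):
      del timestep

    def observe(self, action, next_timestep):
      del action, next_timestep

    def update(self):
      pass

  loop = TrainEnvironmentLoop(
      environment=_FakeEnv(),
      actor=_FakeActor(),
      rewarder=_ExampleConstantRewarder(),
      logger=loggers.TerminalLogger(label='smoke_test'))
  loop.run(num_steps=10)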