# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for the contrastive RL agent."""

import functools
from typing import Dict, Optional, Sequence

from acme import types
from acme import wrappers
from acme.agents.jax import actors
from acme.jax import networks as network_lib
from acme.jax import utils
from acme.utils.observers import base as observers_base
from acme.wrappers import base
import dm_env
import jax
import numpy as np

from cvl_public import env_utils


def obs_to_goal_1d(obs, start_index, end_index):
  assert len(obs.shape) == 1
  return obs_to_goal_2d(obs[None], start_index, end_index)[0]


def obs_to_goal_2d(obs, start_index, end_index):
  assert len(obs.shape) == 2
  if end_index == -1:
    return obs[:, start_index:]
  else:
    return obs[:, start_index:end_index]


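# Illustrative example (hypothetical layout, not from the original file): for a
# 1-D observation [x, y, vx, vy], start_index/end_index select the slice that
# is compared against the goal:
#
#   obs = np.array([1.0, 2.0, 0.1, -0.3])
#   obs_to_goal_1d(obs, start_index=0, end_index=2)   # -> array([1., 2.])
#   obs_to_goal_1d(obs, start_index=2, end_index=-1)  # -> array([0.1, -0.3])

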
class SuccessObserver(observers_base.EnvLoopObserver):
  """Measures success by whether any of the rewards in an episode are positive.
  """

  def __init__(self):
    self._rewards = []
    self._success = []

  def observe_first(self, env: dm_env.Environment,
                    timestep: dm_env.TimeStep) -> None:
    """Observes the initial state."""
    if self._rewards:
      success = np.sum(self._rewards) >= 1
      self._success.append(success)
    self._rewards = []

  def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
              action: np.ndarray) -> None:
    """Records one environment step."""
    self._rewards.append(timestep.reward)

  def get_metrics(self) -> Dict[str, float]:
    """Returns metrics collected for the current episode."""
    return {
        'success': float(np.sum(self._rewards) >= 1),
        'success_1000': np.mean(self._success[-1000:]),
    }


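# Note (derived from the class above, not part of the original file): `success`
# is 1.0 when the rewards in the current episode sum to at least 1, and
# `success_1000` averages that indicator over up to the last 1000 completed
# episodes. For example, a sparse reward sequence [0, 0, 1, 0] yields
# success = 1.0.

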
class DistanceObserver(observers_base.EnvLoopObserver):
  """Observer that measures the L2 distance to the goal."""

  def __init__(self, obs_dim: int, start_index: int, end_index: int,
               smooth: bool = True):
    self._distances = []
    self._obs_dim = obs_dim
    self._obs_to_goal = functools.partial(
        obs_to_goal_1d, start_index=start_index, end_index=end_index)
    self._smooth = smooth
    self._history = {}

  def _get_distance(self, env: dm_env.Environment,
                    timestep: dm_env.TimeStep) -> float:
    if hasattr(env, '_dist'):
      assert env._dist  # pylint: disable=protected-access
      return env._dist[-1]  # pylint: disable=protected-access
    else:
      # Note that the timestep comes from the environment, which has already
      # had some goal coordinates removed.
      obs = timestep.observation[:self._obs_dim]
      goal = timestep.observation[self._obs_dim:]
      if self._obs_to_goal(obs).shape == goal.shape:
        dist = np.linalg.norm(self._obs_to_goal(obs) - goal)
      else:
        dist = 0.
      return dist

  def observe_first(self, env: dm_env.Environment,
                    timestep: dm_env.TimeStep) -> None:
    """Observes the initial state."""
    if self._smooth and self._distances:
      for key, value in self._get_current_metrics().items():
        self._history[key] = self._history.get(key, []) + [value]
    try:
      self._distances = [self._get_distance(env, timestep)]
    except:  # pylint: disable=bare-except
      self._distances = []

  def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
              action: np.ndarray) -> None:
    """Records one environment step."""
    self._distances.append(self._get_distance(env, timestep))

  def _get_current_metrics(self) -> Dict[str, float]:
    metrics = {
        'init_dist': self._distances[0],
        'final_dist': self._distances[-1],
        'delta_dist': self._distances[0] - self._distances[-1],
        'min_dist': min(self._distances),
    }
    return metrics

  def get_metrics(self) -> Dict[str, float]:
    """Returns metrics collected for the current episode."""
    metrics = self._get_current_metrics()
    if self._smooth:
      for key, vec in self._history.items():
        for size in [10, 100, 1000]:
          metrics['%s_%d' % (key, size)] = np.nanmean(vec[-size:])
    return metrics


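# A minimal sketch (assumptions: `env`, `actor`, and `obs_dim` are created
# elsewhere, and the goal indices below are illustrative) of how these
# observers plug into an acme environment loop; the loop merges each
# observer's get_metrics() output into the per-episode log:
#
#   from acme.environment_loop import EnvironmentLoop
#
#   loop = EnvironmentLoop(
#       env, actor,
#       observers=[SuccessObserver(),
#                  DistanceObserver(obs_dim=obs_dim, start_index=0,
#                                   end_index=2)])
#   loop.run(num_episodes=100)

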
class ObservationFilterWrapper(base.EnvironmentWrapper):
  """Wrapper that exposes just the desired goal coordinates."""

  def __init__(self, environment: dm_env.Environment,
               idx: np.ndarray):
    """Initializes a new ObservationFilterWrapper.

    Args:
      environment: Environment to wrap.
      idx: Sequence of indices of coordinates to keep.
    """
    super().__init__(environment)
    self._idx = idx
    observation_spec = environment.observation_spec()
    spec_min = self._convert_observation(observation_spec.minimum)
    spec_max = self._convert_observation(observation_spec.maximum)
    self._observation_spec = dm_env.specs.BoundedArray(
        shape=spec_min.shape,
        dtype=spec_min.dtype,
        minimum=spec_min,
        maximum=spec_max,
        name='state')

  def _convert_observation(self, observation):
    return observation[self._idx]

  def step(self, action):
    timestep = self._environment.step(action)
    return timestep._replace(
        observation=self._convert_observation(timestep.observation))

  def reset(self):
    timestep = self._environment.reset()
    return timestep._replace(
        observation=self._convert_observation(timestep.observation))

  def observation_spec(self):
    return self._observation_spec


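# Illustrative example (hypothetical layout, not from the original file): if
# the underlying observation is [x, y, vx, vy, goal_x, goal_y], then
#
#   env = ObservationFilterWrapper(env, idx=np.array([0, 1, 4, 5]))
#
# keeps only the positions and the goal, and env.observation_spec() reports
# shape (4,) with the corresponding bounds.

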
def make_environment(env_name: str, start_index: int, end_index: int,
                     seed: int):
  """Creates the environment.

  Args:
    env_name: name of the environment.
    start_index: first index of the observation to use in the goal.
    end_index: final index of the observation to use in the goal. The goal
      is then obs[start_index:end_index].
    seed: random seed.

  Returns:
    env: the environment.
    obs_dim: integer specifying the size of the observations, before
      the start_index/end_index is applied.
  """
  np.random.seed(seed)
  gym_env, obs_dim, max_episode_steps, is_gc = env_utils.load(env_name)
  if is_gc:
    goal_indices = obs_dim + obs_to_goal_1d(np.arange(obs_dim), start_index,
                                            end_index)
    indices = np.concatenate([
        np.arange(obs_dim),
        goal_indices
    ])
  env = wrappers.GymWrapper(gym_env)
  env = wrappers.StepLimitWrapper(env, step_limit=max_episode_steps)
  if is_gc:
    env = ObservationFilterWrapper(env, indices)
  if env_name.startswith('ant_'):
    env = wrappers.CanonicalSpecWrapper(env)
  return env, obs_dim


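# A minimal usage sketch; the environment name below is hypothetical and must
# be one that env_utils.load understands:
#
#   env, obs_dim = make_environment(
#       'fetch_reach', start_index=0, end_index=3, seed=0)
#   timestep = env.reset()
#   # timestep.observation holds the obs_dim observation coordinates followed
#   # by the filtered goal coordinates (when the task is goal-conditioned).

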
class InitiallyRandomActor(actors.GenericActor):
  """Actor that takes actions uniformly at random until the actor is updated.
  """

  def select_action(self,
                    observation: network_lib.Observation) -> types.NestedArray:
    # An all-zero bias in the first linear layer is used as a proxy for
    # "parameters have not been updated yet"; in that case, sample an action
    # uniformly from [-1, 1] instead of querying the policy.
    if (self._params['mlp/~/linear_0']['b'] == 0).all():
      shape = self._params['Normal/~/linear']['b'].shape
      rng, self._state = jax.random.split(self._state)
      action = jax.random.uniform(key=rng, shape=shape,
                                  minval=-1.0, maxval=1.0)
    else:
      action, self._state = self._policy(self._params, observation,
                                         self._state)
    return utils.to_numpy(action)