# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for the contrastive RL agent."""

import functools
from typing import Dict, Optional, Sequence

from acme import types
from acme import wrappers
from acme.agents.jax import actors
from acme.jax import networks as network_lib
from acme.jax import utils
from acme.utils.observers import base as observers_base
from acme.wrappers import base
import dm_env
import jax
import numpy as np

from cvl_public import env_utils
def obs_to_goal_1d(obs, start_index, end_index):
  assert len(obs.shape) == 1
  return obs_to_goal_2d(obs[None], start_index, end_index)[0]


def obs_to_goal_2d(obs, start_index, end_index):
  assert len(obs.shape) == 2
  if end_index == -1:
    # An end_index of -1 selects everything from start_index onwards.
    return obs[:, start_index:]
  else:
    return obs[:, start_index:end_index]
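# Example (illustrative values, not part of the original file): with
# start_index=2 and end_index=4, the goal is the last two coordinates of a
# 4-dimensional observation.
#
#   obs = np.array([0.0, 1.0, 2.0, 3.0])
#   obs_to_goal_1d(obs, 2, 4)         # -> array([2., 3.])
#   obs_to_goal_2d(obs[None], 2, -1)  # -> array([[2., 3.]])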
class SuccessObserver(observers_base.EnvLoopObserver):
  """Measures success by whether any of the rewards in an episode are positive.
  """

  def __init__(self):
    self._rewards = []
    self._success = []

  def observe_first(self, env, timestep):
    """Observes the initial state."""
    if self._rewards:
      # Process the previous episode: with the binary (0/1) rewards used
      # here, a return of at least 1 means the goal was reached.
      success = np.sum(self._rewards) >= 1
      self._success.append(success)
    self._rewards = []

  def observe(self, env, timestep, action):
    """Records one environment step."""
    self._rewards.append(timestep.reward)

  def get_metrics(self):
    """Returns metrics collected for the current episode."""
    return {
        'success': float(np.sum(self._rewards) >= 1),
        'success_1000': np.mean(self._success[-1000:]),
    }
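# Sketch of how these observers are typically attached (assumes an existing
# `environment` and `actor`; Acme's EnvironmentLoop accepts an `observers`
# sequence and merges each observer's metrics into its logged results):
#
#   from acme.environment_loop import EnvironmentLoop
#   loop = EnvironmentLoop(environment, actor,
#                          observers=[SuccessObserver()])
#   loop.run(num_episodes=10)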
class DistanceObserver(observers_base.EnvLoopObserver):
  """Observer that measures the L2 distance to the goal."""

  def __init__(self, obs_dim, start_index, end_index, smooth=True):
    self._distances = []
    self._obs_dim = obs_dim
    self._obs_to_goal = functools.partial(
        obs_to_goal_1d, start_index=start_index, end_index=end_index)
    self._smooth = smooth
    self._history = {}

  def _get_distance(self, env, timestep):
    if hasattr(env, '_dist'):
      # Prefer the environment's own distance bookkeeping when available.
      assert env._dist  # pylint: disable=protected-access
      return env._dist[-1]  # pylint: disable=protected-access
    else:
      # Note that the timestep comes from the environment, which has already
      # had some goal coordinates removed.
      obs = timestep.observation[:self._obs_dim]
      goal = timestep.observation[self._obs_dim:]
      if self._obs_to_goal(obs).shape == goal.shape:
        dist = np.linalg.norm(self._obs_to_goal(obs) - goal)
      else:
        dist = 0.
      return dist

  def observe_first(self, env, timestep):
    """Observes the initial state."""
    if self._smooth and self._distances:
      # Archive the previous episode's metrics for the running averages.
      for key, value in self._get_current_metrics().items():
        self._history[key] = self._history.get(key, []) + [value]
    try:
      self._distances = [self._get_distance(env, timestep)]
    except:  # pylint: disable=bare-except
      self._distances = []

  def observe(self, env, timestep, action):
    """Records one environment step."""
    self._distances.append(self._get_distance(env, timestep))

  def _get_current_metrics(self):
    metrics = {
        'init_dist': self._distances[0],
        'final_dist': self._distances[-1],
        'delta_dist': self._distances[0] - self._distances[-1],
        'min_dist': min(self._distances),
    }
    return metrics

  def get_metrics(self):
    """Returns metrics collected for the current episode."""
    metrics = self._get_current_metrics()
    if self._smooth:
      # Also report running averages over the last 10/100/1000 episodes,
      # e.g. 'final_dist_100'.
      for key, vec in self._history.items():
        for size in [10, 100, 1000]:
          metrics['%s_%d' % (key, size)] = np.nanmean(vec[-size:])
    return metrics
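# Illustrative sketch (hypothetical dimensions): an observer for a task whose
# observations have 4 state coordinates, with the goal in coordinates 2:4,
# could be attached to the same loop as above:
#
#   observer = DistanceObserver(obs_dim=4, start_index=2, end_index=4)
#   # After each episode, get_metrics() reports init/final/delta/min
#   # distances, plus the smoothed '<name>_10/100/1000' variants.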
class ObservationFilterWrapper(base.EnvironmentWrapper):
  """Wrapper that exposes just the desired goal coordinates."""

  def __init__(self, environment, idx):
    """Initializes a new ObservationFilterWrapper.

    Args:
      environment: Environment to wrap.
      idx: Sequence of indices of coordinates to keep.
    """
    super().__init__(environment)
    self._idx = idx
    observation_spec = environment.observation_spec()
    spec_min = self._convert_observation(observation_spec.minimum)
    spec_max = self._convert_observation(observation_spec.maximum)
    self._observation_spec = dm_env.specs.BoundedArray(
        shape=spec_min.shape,
        dtype=spec_min.dtype,
        minimum=spec_min,
        maximum=spec_max,
        name='state')

  def _convert_observation(self, observation):
    return observation[self._idx]

  def step(self, action):
    timestep = self._environment.step(action)
    return timestep._replace(
        observation=self._convert_observation(timestep.observation))

  def reset(self):
    timestep = self._environment.reset()
    return timestep._replace(
        observation=self._convert_observation(timestep.observation))

  def observation_spec(self):
    return self._observation_spec
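# Example (hypothetical indices): keeping coordinates [0, 1, 2, 3, 6, 7] of
# an 8-dimensional observation yields 6-dimensional observations, and
# observation_spec() reports the correspondingly filtered bounds:
#
#   env = ObservationFilterWrapper(env, np.array([0, 1, 2, 3, 6, 7]))
#   env.observation_spec().shape  # -> (6,)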
def make_environment(env_name, start_index, end_index, seed):
  """Creates the environment.

  Args:
    env_name: name of the environment.
    start_index: first index of the observation to use in the goal.
    end_index: final index of the observation to use in the goal. The goal
      is then obs[start_index:end_index].
    seed: random seed.

  Returns:
    env: the environment.
    obs_dim: integer specifying the size of the observations, before
      the start_index/end_index is applied.
  """
  np.random.seed(seed)
  gym_env, obs_dim, max_episode_steps, is_gc = env_utils.load(env_name)
  if is_gc:
    # For goal-conditioned environments, keep all obs_dim observation
    # coordinates plus the goal coordinates selected by start/end_index.
    goal_indices = obs_dim + obs_to_goal_1d(np.arange(obs_dim), start_index,
                                            end_index)
    indices = np.concatenate([
        np.arange(obs_dim),
        goal_indices
    ])
  env = wrappers.GymWrapper(gym_env)
  env = wrappers.StepLimitWrapper(env, step_limit=max_episode_steps)
  if is_gc:
    env = ObservationFilterWrapper(env, indices)
  if env_name.startswith('ant_'):
    env = wrappers.CanonicalSpecWrapper(env)
  return env, obs_dim
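# Worked example (hypothetical numbers): with obs_dim=4, start_index=2 and
# end_index=4, a goal-conditioned observation has 8 entries [obs(4), goal(4)].
# obs_to_goal_1d(np.arange(4), 2, 4) -> [2, 3], so goal_indices = [6, 7] and
# the ObservationFilterWrapper keeps indices [0, 1, 2, 3, 6, 7].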
class InitiallyRandomActor(actors.GenericActor):
  """Actor that takes actions uniformly at random until the actor is updated.
  """

  def select_action(self, observation):
    # Before the first learner update, the policy MLP's biases are still at
    # their zero initialization; use that as the signal to sample actions
    # uniformly from [-1, 1] instead of querying the untrained policy.
    if (self._params['mlp/~/linear_0']['b'] == 0).all():
      shape = self._params['Normal/~/linear']['b'].shape
      rng, self._state = jax.random.split(self._state)
      action = jax.random.uniform(key=rng, shape=shape,
                                  minval=-1.0, maxval=1.0)
    else:
      action, self._state = self._policy(self._params, observation,
                                         self._state)
    return utils.to_numpy(action)
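# Note: the all-zeros bias check above assumes the policy network is built
# with zero bias initialization (Haiku's default for hk.Linear); any learner
# update perturbs the biases, after which the learned policy takes over.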