google-research
280 строк · 9.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utility for loading the OpenAI Gym Fetch robotics environments."""
17
18import gym19from gym.envs.robotics.fetch import push20from gym.envs.robotics.fetch import reach21import numpy as np22
23
class FetchReachEnv(reach.FetchReachEnv):
    """FetchReach variant that returns a flat, goal-conditioned observation.

    The dict observation produced by the parent environment is flattened into
    one float32 vector: the parent's state vector followed by a same-sized
    goal vector that is zero everywhere except at the position indices, where
    it holds the desired goal. The advertised observation space is an
    unbounded 20-dimensional Box.
    """

    def __init__(self):
        """Build the parent environment and swap in the flat observation space."""
        super().__init__()
        # Keep the parent's space so it can be restored around parent resets.
        self._old_observation_space = self.observation_space
        unbounded = np.full((20,), np.inf)
        self._new_observation_space = gym.spaces.Box(
            low=-unbounded, high=unbounded, dtype=np.float32)
        self.observation_space = self._new_observation_space

    def reset(self):
        """Reset the parent environment and return the flattened observation."""
        # The parent reset runs with the parent's own observation space set.
        self.observation_space = self._old_observation_space
        raw_obs = super().reset()
        self.observation_space = self._new_observation_space
        return self.observation(raw_obs)

    def step(self, action):
        """Advance one step.

        The parent's reward, done flag and info are discarded.

        Returns:
            A tuple (observation, reward, done, info) where reward is 1.0 when
            the achieved goal is within 0.05 of the desired goal and 0.0
            otherwise, done is always False and info is an empty dict.
        """
        raw_obs, _, _, _ = super().step(action)
        goal_error = np.linalg.norm(
            raw_obs['achieved_goal'] - raw_obs['desired_goal'])
        reward = float(goal_error < 0.05)  # Default from Fetch environment.
        return self.observation(raw_obs), reward, False, {}

    def observation(self, observation):
        """Flatten a parent dict observation into [state, goal] as float32."""
        position = slice(0, 3)
        state = observation['observation']
        # Sanity check: the achieved goal must equal state[0:3].
        assert np.all(observation['achieved_goal'] == state[position])
        goal = np.zeros_like(state)
        goal[position] = observation['desired_goal']
        return np.concatenate([state, goal]).astype(np.float32)
61
class FetchPushEnv(push.FetchPushEnv):
    """FetchPush variant that returns a flat, goal-conditioned observation.

    The dict observation produced by the parent environment is flattened into
    one float32 vector: the parent's state vector followed by a same-sized
    goal vector. The goal vector is zero except for entries [0:3] and [3:6],
    which both hold the desired goal. The advertised observation space is an
    unbounded 50-dimensional Box.
    """

    def __init__(self):
        """Build the parent environment and swap in the flat observation space."""
        super().__init__()
        # Keep the parent's space so it can be restored around parent resets.
        self._old_observation_space = self.observation_space
        unbounded = np.full((50,), np.inf)
        self._new_observation_space = gym.spaces.Box(
            low=-unbounded, high=unbounded, dtype=np.float32)
        self.observation_space = self._new_observation_space

    def reset(self):
        """Reset the parent environment and return the flattened observation."""
        # The parent reset runs with the parent's own observation space set.
        self.observation_space = self._old_observation_space
        raw_obs = super().reset()
        self.observation_space = self._new_observation_space
        return self.observation(raw_obs)

    def step(self, action):
        """Advance one step.

        The parent's reward, done flag and info are discarded.

        Returns:
            A tuple (observation, reward, done, info) where reward is 1.0 when
            the achieved goal is within 0.05 of the desired goal and 0.0
            otherwise, done is always False and info is an empty dict.
        """
        raw_obs, _, _, _ = super().step(action)
        goal_error = np.linalg.norm(
            raw_obs['achieved_goal'] - raw_obs['desired_goal'])
        reward = float(goal_error < 0.05)
        return self.observation(raw_obs), reward, False, {}

    def observation(self, observation):
        """Flatten a parent dict observation into [state, goal] as float32."""
        first = 3
        last = 6
        state = observation['observation']
        target = observation['desired_goal']
        # Sanity check: the achieved goal must equal state[3:6].
        assert np.all(observation['achieved_goal'] == state[first:last])
        goal = np.zeros_like(state)
        # The desired goal is written into both leading 3-entry slots.
        goal[:first] = target
        goal[first:last] = target
        return np.concatenate([state, goal]).astype(np.float32)
100
class FetchReachImage(reach.FetchReachEnv):
    """FetchReach variant whose observations are rendered images.

    Every observation is the flattened 64x64 render of the current scene
    concatenated with the flattened render of a goal scene captured during
    reset. The distance to the goal is recorded at reset and on every step.
    """

    def __init__(self):
        """Build the parent environment and install the image observation space."""
        # Metric buffers are created before the parent constructor runs.
        self._dist = []
        self._dist_vec = []
        super().__init__()
        self._old_observation_space = self.observation_space
        flat_size = 64 * 64 * 6  # Current render plus goal render, flattened.
        self._new_observation_space = gym.spaces.Box(
            low=np.full(flat_size, 0),
            high=np.full(flat_size, 255),
            dtype=np.uint8)
        self.observation_space = self._new_observation_space
        self.sim.model.geom_rgba[1:5] = 0  # Hide the lasers

    def reset_metrics(self):
        """Drop all recorded distances, finished episodes included."""
        self._dist = []
        self._dist_vec = []

    def reset(self):
        """Capture a new goal image, reset again and return the first observation.

        Returns:
            The flattened current render concatenated with the goal render.
        """
        # Archive the distances of the episode that just ended, if any.
        if self._dist:
            self._dist_vec.append(self._dist)
            self._dist = []

        # First pass: reset, remember the sampled goal and drive toward it so
        # the goal image can be rendered.
        self.observation_space = self._old_observation_space
        obs = super().reset()
        self.observation_space = self._new_observation_space
        self._goal = obs['desired_goal'].copy()

        for _ in range(10):
            offset = obs['desired_goal'] - obs['achieved_goal']
            # Scaled, clipped offset; the fourth action entry is left at 0.
            action = np.concatenate([np.clip(10 * offset, -1, 1), [0.0]])
            obs, _, _, _ = super().step(action)

        self._goal_img = self.observation(obs)

        # Second pass: reset again to obtain the starting scene.
        self.observation_space = self._old_observation_space
        obs = super().reset()
        self.observation_space = self._new_observation_space
        current_img = self.observation(obs)
        self._dist.append(np.linalg.norm(obs['achieved_goal'] - self._goal))
        return np.concatenate([current_img, self._goal_img])

    def step(self, action):
        """Advance one step and record the distance to the stored goal.

        Returns:
            A tuple (observation, reward, done, info) where reward is 1.0 when
            the achieved goal is within 0.05 of the stored goal and 0.0
            otherwise, done is always False and info is an empty dict.
        """
        obs, _, _, _ = super().step(action)
        goal_error = np.linalg.norm(obs['achieved_goal'] - self._goal)
        self._dist.append(goal_error)
        reward = float(goal_error < 0.05)
        current_img = self.observation(obs)
        return np.concatenate([current_img, self._goal_img]), reward, False, {}

    def observation(self, observation):
        """Render the current scene as a flattened 64x64 image.

        The `observation` argument is not used; the image comes from the
        simulator's current state.
        """
        # Site 0 is moved far away before rendering, presumably to keep the
        # goal marker out of the image (TODO confirm).
        self.sim.data.site_xpos[0] = 1_000_000
        frame = self.render(mode='rgb_array', height=64, width=64)
        return frame.flatten()

    def _viewer_setup(self):
        """Run the parent's viewer setup, then position the camera."""
        super()._viewer_setup()
        camera = self.viewer.cam
        camera.lookat[...] = np.array([1.2, 0.8, 0.5])
        camera.distance = 0.8
        camera.azimuth = 180
        camera.elevation = -30
170
class FetchPushImage(push.FetchPushEnv):
    """FetchPush variant whose observations are rendered images.

    Every observation is the flattened 64x64 render of the current scene
    concatenated with the flattened render of a goal scene captured during
    reset. The xy distance between the block and the goal is recorded at
    reset and on every step.

    Args:
        camera: 'camera1' or 'camera2'; selects the viewpoint configured in
            `_viewer_setup`. Any other value raises NotImplementedError when
            the viewer is set up.
        start_at_obj: If true, reset moves the hand next to the block before
            the episode starts; otherwise five random actions are taken.
        rand_y: If false, the block's y coordinate is pinned to 0.75 while
            the goal scene is generated.
    """

    def __init__(self, camera='camera2', start_at_obj=True, rand_y=False):
        # Options and metric buffers are stored before the parent constructor
        # runs.
        self._start_at_obj = start_at_obj
        self._rand_y = rand_y
        self._camera_name = camera
        self._dist = []
        self._dist_vec = []
        super().__init__()
        self._old_observation_space = self.observation_space
        flat_size = 64 * 64 * 6  # Current render plus goal render, flattened.
        self._new_observation_space = gym.spaces.Box(
            low=np.full(flat_size, 0),
            high=np.full(flat_size, 255),
            dtype=np.uint8)
        self.observation_space = self._new_observation_space
        self.sim.model.geom_rgba[1:5] = 0  # Hide the lasers

    def reset_metrics(self):
        """Drop all recorded distances, finished episodes included."""
        self._dist_vec = []
        self._dist = []

    def _move_hand_to_obj(self):
        """Step the parent environment until the hand is close to the block.

        The target is the achieved goal shifted by -0.02 along x. The loop
        stops once the hand is within 0.06 of the target, or after 100 steps.
        Parent steps are used so the distance metrics are not touched.
        """
        s = super()._get_obs()
        for _ in range(100):
            hand = s['observation'][:3]
            obj = s['achieved_goal'] + np.array([-0.02, 0.0, 0.0])
            delta = obj - hand
            if np.linalg.norm(delta) < 0.06:
                break
            a = np.concatenate([np.clip(delta, -1, 1), [0.0]])
            s, _, _, _ = super().step(a)

    def reset(self):
        """Capture a new goal image, build the start scene and return it.

        If the block ends up below z = 0.4 in either scene (it has fallen off
        the table), the whole reset is retried. A retried reset does not
        leave any entry in the distance metrics.

        Returns:
            The flattened current render concatenated with the goal render.
        """
        # Archive the distances of the episode that just ended, if any.
        if self._dist:
            self._dist_vec.append(self._dist)
            self._dist = []

        # Generate the new goal image.
        self.observation_space = self._old_observation_space
        s = super().reset()
        self.observation_space = self._new_observation_space
        # Eight parent steps with -1 on the first action dimension; presumably
        # this retracts the gripper along -x (TODO confirm).
        for _ in range(8):
            super().step(np.array([-1.0, 0.0, 0.0, 0.0]))
        object_qpos = self.sim.data.get_joint_qpos('object0:joint')
        if not self._rand_y:
            object_qpos[1] = 0.75
        self.sim.data.set_joint_qpos('object0:joint', object_qpos)
        self._move_hand_to_obj()
        self._goal_img = self.observation(s)
        block_xyz = self.sim.data.get_joint_qpos('object0:joint')[:3]
        if block_xyz[2] < 0.4:  # If block has fallen off the table, recurse.
            print('Bad reset, recursing.')
            return self.reset()
        self._goal = block_xyz[:2].copy()

        # Build the starting scene with the block at a fixed xy position.
        self.observation_space = self._old_observation_space
        s = super().reset()
        self.observation_space = self._new_observation_space
        for _ in range(8):
            super().step(np.array([-1.0, 0.0, 0.0, 0.0]))
        object_qpos = self.sim.data.get_joint_qpos('object0:joint')
        object_qpos[:2] = np.array([1.15, 0.75])
        self.sim.data.set_joint_qpos('object0:joint', object_qpos)
        if self._start_at_obj:
            self._move_hand_to_obj()
        else:
            for _ in range(5):
                super().step(self.action_space.sample())

        block_xyz = self.sim.data.get_joint_qpos('object0:joint')[:3].copy()
        # Check for a fallen block BEFORE recording the distance. Otherwise
        # the retried reset would archive a bogus one-element episode into
        # `_dist_vec`.
        if block_xyz[2] < 0.4:  # If block has fallen off the table, recurse.
            print('Bad reset, recursing.')
            return self.reset()
        img = self.observation(s)
        dist = np.linalg.norm(block_xyz[:2] - self._goal)
        self._dist.append(dist)
        return np.concatenate([img, self._goal_img])

    def step(self, action):
        """Advance one step and record the block's xy distance to the goal.

        Returns:
            A tuple (observation, reward, done, info) where reward is 1.0 when
            the block is within 0.05 of the goal in xy and 0.0 otherwise,
            done is always False and info is an empty dict.
        """
        s, _, _, _ = super().step(action)
        block_xy = self.sim.data.get_joint_qpos('object0:joint')[:2]
        dist = np.linalg.norm(block_xy - self._goal)
        self._dist.append(dist)
        done = False
        r = float(dist < 0.05)  # Taken from the original task code.
        info = {}
        img = self.observation(s)
        return np.concatenate([img, self._goal_img]), r, done, info

    def observation(self, observation):
        """Render the current scene as a flattened 64x64 image.

        The `observation` argument is not used; the image comes from the
        simulator's current state.
        """
        # Site 0 is moved far away before rendering, presumably to keep the
        # goal marker out of the image (TODO confirm).
        self.sim.data.site_xpos[0] = 1_000_000
        img = self.render(mode='rgb_array', height=64, width=64)
        return img.flatten()

    def _viewer_setup(self):
        """Run the parent's viewer setup, then position the selected camera.

        Raises:
            NotImplementedError: If the camera name given at construction is
                neither 'camera1' nor 'camera2'.
        """
        super()._viewer_setup()
        if self._camera_name == 'camera1':
            self.viewer.cam.lookat[...] = np.array([1.2, 0.8, 0.4])
            self.viewer.cam.distance = 0.9
            self.viewer.cam.azimuth = 180
            self.viewer.cam.elevation = -40
        elif self._camera_name == 'camera2':
            self.viewer.cam.lookat[...] = np.array([1.25, 0.8, 0.4])
            self.viewer.cam.distance = 0.65
            self.viewer.cam.azimuth = 90
            self.viewer.cam.elevation = -40
        else:
            raise NotImplementedError