google-research
644 lines · 20.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Environments for experiments with RCE.
17"""
18
19import inspect20import os21
22from absl import logging23import d4rl # pylint: disable=unused-import24import gin25import gym26from metaworld.envs.mujoco import sawyer_xyz27import numpy as np28from tf_agents.environments import suite_gym29from tf_agents.environments import tf_py_environment30import tqdm31
32# We need to import d4rl so that gym registers the environments.
33os.environ['SDL_VIDEODRIVER'] = 'dummy'34
35
36def _get_image_obs(self):37# The observation returned here should be in [0, 255].38obs = self.get_image(width=84, height=84)39return obs[::-1]40
41
@gin.configurable
def load_env(env_name, max_episode_steps=None):
  """Loads an environment.

  Args:
    env_name: Name of the environment. Either one of the custom Sawyer tasks
      defined in this file, or any name registered with gym.
    max_episode_steps: Maximum number of steps per episode. If None, a
      per-environment default is used (for the custom Sawyer tasks) or the
      limit from the gym spec (for registered environments).
  Returns:
    tf_env: A TFPyEnvironment.
  """
  # Constructors and default episode lengths for the custom Sawyer tasks.
  sawyer_tasks = {
      'sawyer_reach': (SawyerReach, 51),
      'sawyer_push': (SawyerPush, 151),
      'sawyer_lift': (SawyerLift, 151),
      'sawyer_drawer_open': (SawyerDrawerOpen, 151),
      'sawyer_drawer_close': (SawyerDrawerClose, 151),
      'sawyer_box_close': (SawyerBoxClose, 151),
      'sawyer_bin_picking': (SawyerBinPicking, 151),
  }
  if env_name in sawyer_tasks:
    constructor, default_steps = sawyer_tasks[env_name]
    gym_env = constructor()
  else:
    gym_spec = gym.spec(env_name)
    gym_env = gym_spec.make()
    default_steps = gym_spec.max_episode_steps
  # Bug fix: an explicitly passed max_episode_steps used to be silently
  # overwritten. It is now honored, falling back to the default when None.
  if max_episode_steps is None:
    max_episode_steps = default_steps
  env = suite_gym.wrap_env(
      gym_env,
      max_episode_steps=max_episode_steps)
  tf_env = tf_py_environment.TFPyEnvironment(env)
  return tf_env
84
@gin.configurable(denylist=['env', 'env_name'])
def get_data(env, env_name, num_expert_obs=200, terminal_offset=50):
  """Loads the success examples.

  Args:
    env: A PyEnvironment for which we want to generate success examples.
    env_name: The name of the environment.
    num_expert_obs: The number of success examples to generate.
    terminal_offset: For the d4rl datasets, we randomly subsample the last N
      steps to use as success examples. The terminal_offset parameter is N.
  Returns:
    expert_obs: Array with the success examples.
  """
  d4rl_tasks = ('hammer-human-v0', 'door-human-v0', 'relocate-human-v0')
  if env_name in d4rl_tasks:
    # Use the final terminal_offset steps of every trajectory as successes.
    dataset = env.get_dataset()
    terminal_steps = np.where(dataset['terminals'])[0]
    candidates = np.concatenate(
        [dataset['observations'][t - terminal_offset:t]
         for t in terminal_steps],
        axis=0)
    picked = np.random.choice(
        len(candidates), size=num_expert_obs, replace=False)
    expert_obs = candidates[picked]
  else:
    # For environments where we generate the expert dataset on the fly, we can
    # improve performance by only generating the number of expert observations
    # that we'll actually use. Not all environments support this, so we first
    # check whether the environment's get_dataset method accepts num_obs.
    accepts_num_obs = (
        'num_obs' in inspect.getfullargspec(env.get_dataset).args)
    if accepts_num_obs:
      dataset = env.get_dataset(num_obs=num_expert_obs)
    else:
      dataset = env.get_dataset()
    picked = np.random.choice(
        dataset['observations'].shape[0], size=num_expert_obs, replace=False)
    expert_obs = dataset['observations'][picked]
  if 'image' in env_name:
    # Image observations are stored as bytes.
    expert_obs = expert_obs.astype(np.uint8)
  logging.info('Done loading expert observations')
  return expert_obs
126
class SawyerReach(sawyer_xyz.SawyerReachPushPickPlaceEnv):
  """A simple reaching environment."""

  def __init__(self):
    super(SawyerReach, self).__init__(task_type='reach')
    self.initialize_camera(self.init_camera)

  def step(self, action):
    """Steps the env, replacing the reward with the decrease in distance."""
    prev_obs = self._get_obs()
    dist_pre = np.linalg.norm(prev_obs[:3] - prev_obs[3:])
    state, _, _, info = super(SawyerReach, self).step(action)
    dist_post = np.linalg.norm(state[:3] - state[3:])
    # Dense reward: positive when the hand moved toward the object.
    return state, dist_pre - dist_post, False, info

  @gin.configurable(module='SawyerReach')
  def reset(self, random=False, width=1.0, random_color=False,
            random_size=False):
    """Resets the env, optionally randomizing the object's pose/appearance."""
    if random_color:
      geom = self.model.geom_name2id('objGeom')
      color = np.random.uniform(np.zeros(3), np.ones(3))
      self.model.geom_rgba[geom, :] = np.concatenate([color, [1.0]])
    if random_size:
      geom = self.model.geom_name2id('objGeom')
      size = np.random.uniform(np.array([0.01, 0.005, 0.0]),
                               np.array([0.05, 0.045, 0.0]))
      self.model.geom_size[geom, :] = size
    super(SawyerReach, self).reset()
    if random:
      lo = np.array([-0.2, 0.4, 0.02])
      hi = np.array([0.2, 0.8, 0.02])
      if width != 1:
        # Shrink/grow the sampling box around its center by `width`.
        center = (lo + hi) / 2.0
        lo = center - width * (center - lo)
        hi = center + width * (hi - center)
      self._set_obj_xyz_quat(np.random.uniform(low=lo, high=hi), 0.0)
    # Hide the default goals and other markers. We use the puck position as
    # the goal. This must happen after self._set_obj_xyz_quat(...).
    self._state_goal = 10 * np.ones(3)
    self._set_goal_marker(self._state_goal)
    return self._get_obs()

  def _get_expert_obs(self):
    """Returns a success example: the hand placed at the object."""
    self.reset()
    # Don't use the observation returned from self.reset because this will be
    # an image for SawyerReachImage.
    obs = self._get_obs()
    self.data.set_mocap_pos('mocap', obs[3:6])
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    # Hide the markers, which get reset after every simulation step.
    self._set_goal_marker(self._state_goal)
    return self._get_obs()

  @gin.configurable(module='SawyerReach')
  def init_camera(self, camera, mode='default'):
    """Configures the camera to one of several preset viewpoints."""
    # mode -> (lookat, distance, elevation, azimuth)
    presets = {
        'human': ((0.6, 1.0, 0.5), 0.5, -20, 230),
        'default': ((0, 0.85, 0.3), 0.4, -35, 270),
        'v2': ((0, 0.6, 0.0), 0.7, -35, 180),
    }
    if mode not in presets:
      raise NotImplementedError
    lookat, distance, elevation, azimuth = presets[mode]
    for axis, value in enumerate(lookat):
      camera.lookat[axis] = value
    camera.distance = distance
    camera.elevation = elevation
    camera.azimuth = azimuth
    camera.trackbodyid = -1

  def get_dataset(self, num_obs=256):
    """Generates success examples for training.

    This generates examples at ~145 observations / sec. When using image
    observations it slows down to ~17 FPS.
    """
    actions = [self.action_space.sample() for _ in range(num_obs)]
    observations = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    return {
        'observations': np.array(observations, dtype=np.float32),
        'actions': np.array(actions, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
232
class SawyerPush(sawyer_xyz.SawyerReachPushPickPlaceEnv):
  """A pushing environment."""

  def __init__(self):
    super(SawyerPush, self).__init__(task_type='push')
    self.initialize_camera(self.init_camera)
    # Fixed goal position for the puck.
    self._goal = np.array([0.1, 0.85, 0.02])

  def step(self, action):
    """Steps the env, replacing the reward with the decrease in distance."""
    prev_obs = self._get_obs()
    dist_pre = np.linalg.norm(prev_obs[3:] - self._goal)
    state, _, _, info = super(SawyerPush, self).step(action)
    dist_post = np.linalg.norm(state[3:] - self._goal)
    # Dense reward: positive when the puck moved toward the goal.
    return state, dist_pre - dist_post, False, info

  @gin.configurable(module='SawyerPush')
  def _get_expert_obs(self, hand_at_puck=True, wide=False, off_table=False):
    """Returns a success example with the puck near the goal region."""
    self.reset()
    if wide:
      lo, hi = [-0.15, 0.8, 0.02], [0.15, 0.9, 0.02]
    else:
      lo, hi = [0.05, 0.8, 0.02], [0.15, 0.9, 0.02]
    puck_pos = np.random.uniform(low=lo, high=hi)
    if off_table:
      assert not wide
      assert not hand_at_puck
      # Place the puck far away, effectively off the table.
      puck_pos = 10 * np.ones(3,)
    self._set_obj_xyz_quat(puck_pos, 0.0)
    if hand_at_puck:
      hand_goal = puck_pos
    else:
      hand_goal = np.random.uniform(low=[-0.2, 0.4, 0.02],
                                    high=[0.2, 0.8, 0.3])
    self.data.set_mocap_pos('mocap', hand_goal)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  @gin.configurable(module='SawyerPush')
  def init_camera(self, camera, mode='default'):
    """Configures the camera to one of several preset viewpoints."""
    # mode -> (lookat, distance, elevation, azimuth)
    presets = {
        'human': ((0.6, 1.0, 0.5), 0.5, -20, 230),
        'default': ((0, 0.9, 0.3), 0.4, -45, 270),
        'front': ((0, 0.85, 0.05), 0.4, 0, 270),
        'side': ((0, 0.7, 0.05), 0.6, 0, 180),
    }
    if mode not in presets:
      raise NotImplementedError
    lookat, distance, elevation, azimuth = presets[mode]
    for axis, value in enumerate(lookat):
      camera.lookat[axis] = value
    camera.distance = distance
    camera.elevation = elevation
    camera.azimuth = azimuth
    camera.trackbodyid = -1

  def get_dataset(self, num_obs=256):
    """Generates success examples at ~145 observations / sec."""
    actions = [self.action_space.sample() for _ in range(num_obs)]
    observations = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    return {
        'observations': np.array(observations, dtype=np.float32),
        'actions': np.array(actions, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
323
class SawyerLift(sawyer_xyz.SawyerReachPushPickPlaceEnv):
  """A task of lifting up an object."""

  # Global mode flag; reset(reset_to_goal=True) only applies during training.
  MODE = 'train'

  def __init__(self):
    super(SawyerLift, self).__init__(task_type='reach')
    self.initialize_camera(self.init_camera)

  def _get_dist(self, z):
    """Distance from height z to the target height interval (0 if inside)."""
    min_height, max_height = self.target_height()
    if min_height <= z <= max_height:
      return 0.0
    return min(abs(z - max_height), abs(z - min_height))

  @gin.configurable(module='SawyerLift')
  def target_height(self, target_height=0.1):
    """Values 0.1 through 0.3 are reasonable."""
    if isinstance(target_height, (tuple, list)):
      min_height, max_height = target_height
      return (min_height, max_height)
    # A scalar target defines a +/- 2cm band around it.
    return (target_height - 0.02, target_height + 0.02)

  def step(self, action):
    """Steps the env, replacing the reward with the decrease in distance."""
    prev_obs = self._get_obs()
    dist_pre = self._get_dist(prev_obs[5])
    state, _, _, info = super(SawyerLift, self).step(action)
    dist_post = self._get_dist(state[5])
    # Dense reward: positive when the object moved toward the target band.
    return state, dist_pre - dist_post, False, info

  def init_camera(self, camera):
    """Fixed camera viewpoint for this task."""
    for axis, value in enumerate((0.6, 1.0, 0.5)):
      camera.lookat[axis] = value
    camera.distance = 0.5
    camera.elevation = -20
    camera.azimuth = 230
    camera.trackbodyid = -1

  @gin.configurable(module='SawyerLift')
  def reset(self, reset_to_goal=False):
    """Resets the env; optionally starts at the goal during training."""
    super(SawyerLift, self).reset()
    if reset_to_goal and self.MODE == 'train':
      self._get_expert_obs(reset=False)
    return self._get_obs()

  @gin.configurable(module='SawyerLift')
  def _get_expert_obs(self, reset=True):
    """Returns a success example: the object held at the target height."""
    if reset:
      self.reset()
    obs = self._get_obs()
    puck_pos = obs[3:6]
    min_height, max_height = self.target_height()
    puck_pos[-1] = np.random.uniform(min_height, max_height)
    self.data.set_mocap_pos('mocap', puck_pos)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    # We have to set the puck position after moving the arm. Otherwise
    # the puck will fall while setting the arm position.
    self._set_obj_xyz_quat(puck_pos, 0.0)
    return self._get_obs()

  def get_dataset(self, num_obs=256):
    """Generates success examples at ~145 observations / sec."""
    actions = [self.action_space.sample() for _ in range(num_obs)]
    observations = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    return {
        'observations': np.array(observations, dtype=np.float32),
        'actions': np.array(actions, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
408
class SawyerDrawerOpen(sawyer_xyz.SawyerDrawerOpenEnv):
  """A drawer opening task."""

  def __init__(self):
    super(SawyerDrawerOpen, self).__init__()
    self.initialize_camera(self.init_camera)

  def step(self, action):
    """Steps the env, replacing the reward with the decrease in distance."""
    prev_obs = self._get_obs()
    dist_pre = np.linalg.norm(prev_obs[4] - self.goal[1])
    state, _, _, info = super(SawyerDrawerOpen, self).step(action)
    dist_post = np.linalg.norm(state[4] - self.goal[1])
    # Dense reward: positive when the drawer moved toward the goal position.
    return state, dist_pre - dist_post, False, info

  def init_camera(self, camera):
    """Fixed camera viewpoint for this task."""
    for axis, value in enumerate((0.0, 0.4, 0.3)):
      camera.lookat[axis] = value
    camera.distance = 1.
    camera.elevation = -20
    camera.azimuth = 160
    camera.trackbodyid = -1

  @gin.configurable(module='SawyerDrawerOpen')
  def _get_expert_obs(self, hand_at_goal=True):
    """Returns a success example with the drawer pulled open."""
    self.reset()
    drawer_pos = np.random.uniform(-0.25, -0.15)
    self._set_obj_xyz(drawer_pos)
    if hand_at_goal:
      hand_goal = self._get_obs()[3:]
    else:
      hand_goal = np.random.uniform(low=[-0.2, 0.4, 0.02],
                                    high=[0.2, 0.8, 0.3])
    self.data.set_mocap_pos('mocap', hand_goal)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  def get_dataset(self, num_obs=256):
    """Generates success examples at ~145 observations / sec."""
    actions = [self.action_space.sample() for _ in range(num_obs)]
    observations = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    return {
        'observations': np.array(observations, dtype=np.float32),
        'actions': np.array(actions, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
462
@gin.configurable
class SawyerDrawerClose(sawyer_xyz.SawyerDrawerCloseEnv):
  """A drawer closing task."""

  def __init__(self, random_init=False):
    super(SawyerDrawerClose, self).__init__(random_init=random_init)
    self.initialize_camera(self.init_camera)

  def step(self, action):
    """Steps the env, replacing the reward with the decrease in distance."""
    prev_obs = self._get_obs()
    dist_pre = np.linalg.norm(prev_obs[4] - self.goal[1])
    state, _, _, info = super(SawyerDrawerClose, self).step(action)
    dist_post = np.linalg.norm(state[4] - self.goal[1])
    # Dense reward: positive when the drawer moved toward the goal position.
    return state, dist_pre - dist_post, False, info

  def init_camera(self, camera):
    """Fixed camera viewpoint for this task."""
    for axis, value in enumerate((0.0, 0.4, 0.3)):
      camera.lookat[axis] = value
    camera.distance = 1.
    camera.elevation = -20
    camera.azimuth = 160
    camera.trackbodyid = -1

  @gin.configurable(module='SawyerDrawerClose')
  def _get_expert_obs(self, hand_at_goal=True):
    """Returns a success example with the drawer pushed closed."""
    self.reset()
    drawer_pos = np.random.uniform(0.0, 0.05)
    self._set_obj_xyz(drawer_pos)
    if hand_at_goal:
      hand_goal = self._get_obs()[3:]
    else:
      hand_goal = np.random.uniform(low=[-0.2, 0.4, 0.02],
                                    high=[0.2, 0.8, 0.3])
    self.data.set_mocap_pos('mocap', hand_goal)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  def get_dataset(self, num_obs=256):
    """Generates success examples at ~145 observations / sec."""
    actions = [self.action_space.sample() for _ in range(num_obs)]
    observations = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    return {
        'observations': np.array(observations, dtype=np.float32),
        'actions': np.array(actions, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
517
class SawyerBoxClose(sawyer_xyz.SawyerBoxCloseEnv):
  """The task is to put a lid on a box.

  The observation dimension is 9: 3 for the hand, 3 for the lid, 3 for the
  goal.
  """

  def __init__(self):
    super(SawyerBoxClose, self).__init__()
    self.initialize_camera(self.init_camera)

  def _get_goal_pos(self, obs):
    """Returns the target lid position: the goal, lowered by 8.5cm.

    Bug fix: this previously returned a *view* into obs and modified it in
    place, silently mutating the caller's observation array. We now copy
    before applying the offset.
    """
    goal_pos = np.array(obs[-3:])
    goal_pos[-1] -= 0.085
    return goal_pos

  def step(self, action):
    """Steps the env, replacing the reward with the decrease in distance.

    The reward is the decrease in distance between the lid and its target
    position; done is always False.
    """
    obs = self._get_obs()
    goal_pos = self._get_goal_pos(obs)
    d_before = np.linalg.norm(obs[3:6] - goal_pos)
    s, _, _, info = super(SawyerBoxClose, self).step(action)
    d_after = np.linalg.norm(s[3:6] - goal_pos)
    r = d_before - d_after
    done = False
    return s, r, done, info

  def _get_expert_obs(self):
    """Returns a success example: the lid placed at the goal position."""
    self.reset()
    obs = self._get_obs()
    # Reuse the helper instead of duplicating the goal-offset computation.
    goal_pos = self._get_goal_pos(obs)
    self._set_obj_xyz_quat(goal_pos, self.obj_init_angle)
    # NOTE: the original code mutated obs[-3:] in place, so the mocap target
    # it passed equaled the offset goal position. We pass goal_pos explicitly
    # to keep the exact same runtime values without the aliasing.
    self.data.set_mocap_pos('mocap', goal_pos)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  def init_camera(self, camera):
    """Fixed camera viewpoint for this task."""
    camera.distance = 1.
    camera.lookat[0] = 0.0
    camera.lookat[1] = 1.0
    camera.lookat[2] = 0.1
    camera.elevation = -10
    camera.azimuth = 270
    camera.trackbodyid = -1

  def get_dataset(self, num_obs=256):
    """Generates success examples at ~273 observations / sec."""
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset
577
class SawyerBinPicking(sawyer_xyz.SawyerBinPickingEnv):
  """A pick and place task."""

  def __init__(self):
    super(SawyerBinPicking, self).__init__()
    self.initialize_camera(self.init_camera)

  def step(self, action):
    """Steps the env, replacing the reward with the decrease in distance."""
    prev_obs = self._get_obs()
    target = np.array([0.12, 0.7, 0.046])
    dist_pre = np.linalg.norm(prev_obs[3:6] - target)
    state, _, _, info = super(SawyerBinPicking, self).step(action)
    dist_post = np.linalg.norm(state[3:6] - target)
    # Dense reward: positive when the object moved toward the target bin.
    return state, dist_pre - dist_post, False, info

  def _get_expert_obs(self):
    """Returns a success example: the object placed in the target bin."""
    self.reset()
    goal_pos = np.random.uniform(
        low=np.array([0.06, 0.64, 0.046]), high=np.array([0.18, 0.76, 0.046]))
    self._set_obj_xyz_quat(goal_pos, self.obj_init_angle)
    self.data.set_mocap_pos('mocap', goal_pos)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  @gin.configurable(module='SawyerBinPicking')
  def init_camera(self, camera, mode='default'):
    """Configures the camera to one of several preset viewpoints."""
    # mode -> (lookat, distance, elevation, azimuth)
    presets = {
        'default': ((0.6, 1.0, 0.5), 0.5, -20, 230),
        'side': ((0.2, 0.7, 0.2), 0.3, -30, 180),
        'front': ((0.0, 0.9, 0.2), 0.3, -30, 270),
    }
    if mode not in presets:
      raise NotImplementedError
    lookat, distance, elevation, azimuth = presets[mode]
    for axis, value in enumerate(lookat):
      camera.lookat[axis] = value
    camera.distance = distance
    camera.elevation = elevation
    camera.azimuth = azimuth
    camera.trackbodyid = -1

  def get_dataset(self, num_obs=256):
    """Generates success examples at ~95 observations / sec."""
    actions = [self.action_space.sample() for _ in range(num_obs)]
    observations = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    return {
        'observations': np.array(observations, dtype=np.float32),
        'actions': np.array(actions, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }