google-research

Форк
0
/
fetch_envs.py 
280 строк · 9.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utility for loading the OpenAI Gym Fetch robotics environments."""
17

18
import gym
19
from gym.envs.robotics.fetch import push
20
from gym.envs.robotics.fetch import reach
21
import numpy as np
22

23

24
class FetchReachEnv(reach.FetchReachEnv):
  """FetchReach with flat float32 observations and a sparse success reward.

  An observation is the underlying env's `observation` vector followed by a
  goal vector of the same length whose first three entries hold
  `desired_goal` and whose remaining entries are zero. Reward is 1.0 when
  `achieved_goal` is within 0.05 of `desired_goal`, else 0.0; episodes never
  signal done.
  """

  def __init__(self):
    super().__init__()
    # Remember the parent's original space; reset() restores it temporarily.
    self._old_observation_space = self.observation_space
    flat_dim = 20  # observation vector + goal vector
    self._new_observation_space = gym.spaces.Box(
        low=np.full((flat_dim,), -np.inf),
        high=np.full((flat_dim,), np.inf),
        dtype=np.float32)
    self.observation_space = self._new_observation_space

  def reset(self):
    # The parent reset runs with the original space in place (presumably it
    # requires it); the flat space is reinstated before returning.
    self.observation_space = self._old_observation_space
    raw = super().reset()
    self.observation_space = self._new_observation_space
    return self.observation(raw)

  def step(self, action):
    raw, _, _, _ = super().step(action)
    gap = np.linalg.norm(raw['achieved_goal'] - raw['desired_goal'])
    # 0.05 mirrors the default success threshold of the Fetch environment.
    reward = 1.0 if gap < 0.05 else 0.0
    return self.observation(raw), reward, False, {}

  def observation(self, observation):
    """Flattens a goal-env observation dict into a float32 vector.

    Args:
      observation: Dict with 'observation', 'achieved_goal' and
        'desired_goal' entries.

    Returns:
      `observation['observation']` concatenated with a goal vector of equal
      length (target in entries 0:3, zeros elsewhere), cast to float32.
    """
    lo, hi = 0, 3  # Slice of the state that must equal 'achieved_goal'.
    state = observation['observation']
    assert np.all(observation['achieved_goal'] == state[lo:hi])
    goal = np.zeros_like(state)
    goal[lo:hi] = observation['desired_goal']
    return np.concatenate([state, goal]).astype(np.float32)
60

61

62
class FetchPushEnv(push.FetchPushEnv):
  """FetchPush with flat float32 observations and a sparse success reward.

  An observation is the underlying env's `observation` vector followed by a
  goal vector of the same length; the goal vector carries `desired_goal` in
  both its first three entries and entries 3:6, and zeros elsewhere. Reward
  is 1.0 when `achieved_goal` is within 0.05 of `desired_goal`, else 0.0;
  episodes never signal done.
  """

  def __init__(self):
    super().__init__()
    # Remember the parent's original space; reset() restores it temporarily.
    self._old_observation_space = self.observation_space
    flat_dim = 50  # observation vector + goal vector
    self._new_observation_space = gym.spaces.Box(
        low=np.full((flat_dim,), -np.inf),
        high=np.full((flat_dim,), np.inf),
        dtype=np.float32)
    self.observation_space = self._new_observation_space

  def reset(self):
    # The parent reset runs with the original space in place (presumably it
    # requires it); the flat space is reinstated before returning.
    self.observation_space = self._old_observation_space
    raw = super().reset()
    self.observation_space = self._new_observation_space
    return self.observation(raw)

  def step(self, action):
    raw, _, _, _ = super().step(action)
    gap = np.linalg.norm(raw['achieved_goal'] - raw['desired_goal'])
    reward = 1.0 if gap < 0.05 else 0.0
    return self.observation(raw), reward, False, {}

  def observation(self, observation):
    """Flattens a goal-env observation dict into a float32 vector.

    Args:
      observation: Dict with 'observation', 'achieved_goal' and
        'desired_goal' entries.

    Returns:
      `observation['observation']` concatenated with a goal vector of equal
      length, cast to float32.
    """
    lo, hi = 3, 6  # Slice of the state that must equal 'achieved_goal'.
    state = observation['observation']
    assert np.all(observation['achieved_goal'] == state[lo:hi])
    target = observation['desired_goal']
    goal = np.zeros_like(state)
    # The target fills both the leading slot and the achieved-goal slot,
    # presumably so the goal state also has the gripper at the target.
    goal[:lo] = target
    goal[lo:hi] = target
    return np.concatenate([state, goal]).astype(np.float32)
99

100

101
class FetchReachImage(reach.FetchReachEnv):
  """FetchReach with flattened 64x64 rendered-image observations.

  Each observation is the current rendered frame concatenated with a goal
  frame. The goal frame is rendered at the start of every episode after
  steering the gripper toward the freshly sampled goal. Reward is 1.0 when
  `achieved_goal` is within 0.05 of the stored goal, else 0.0. Per-step
  distances are recorded in `_dist` and archived per episode in `_dist_vec`.
  """

  def __init__(self):
    # Metric buffers exist before the parent constructor runs.
    self._dist = []
    self._dist_vec = []
    super().__init__()
    self._old_observation_space = self.observation_space
    pixels = 64 * 64 * 6  # two stacked 64x64 frames (current + goal)
    self._new_observation_space = gym.spaces.Box(
        low=np.full(pixels, 0),
        high=np.full(pixels, 255),
        dtype=np.uint8)
    self.observation_space = self._new_observation_space
    self.sim.model.geom_rgba[1:5] = 0  # Hide the lasers

  def reset_metrics(self):
    """Clears all recorded distances."""
    self._dist = []
    self._dist_vec = []

  def _parent_reset(self):
    """Runs the parent reset() with the original observation space active."""
    self.observation_space = self._old_observation_space
    raw = super().reset()
    self.observation_space = self._new_observation_space
    return raw

  def reset(self):
    # Archive the finished episode's distance trace, if one was recorded.
    if len(self._dist) > 0:
      self._dist_vec.append(self._dist)
    self._dist = []

    # First pass: sample a goal, steer the gripper toward it for 10 steps and
    # snapshot the resulting frame as the goal image.
    raw = self._parent_reset()
    self._goal = raw['desired_goal'].copy()
    for _ in range(10):
      offset = raw['desired_goal'] - raw['achieved_goal']
      command = np.concatenate([np.clip(10 * offset, -1, 1), [0.0]])
      raw, _, _, _ = super().step(command)
    self._goal_img = self.observation(raw)

    # Second pass: a fresh start state for the episode itself.
    raw = self._parent_reset()
    frame = self.observation(raw)
    self._dist.append(np.linalg.norm(raw['achieved_goal'] - self._goal))
    return np.concatenate([frame, self._goal_img])

  def step(self, action):
    raw, _, _, _ = super().step(action)
    gap = np.linalg.norm(raw['achieved_goal'] - self._goal)
    self._dist.append(gap)
    reward = float(gap < 0.05)
    frame = self.observation(raw)
    return np.concatenate([frame, self._goal_img]), reward, False, {}

  def observation(self, observation):
    """Renders the current scene; the `observation` argument is unused."""
    # Site 0 is moved far away, presumably to keep its marker out of frame.
    self.sim.data.site_xpos[0] = 1_000_000
    return self.render(mode='rgb_array', height=64, width=64).flatten()

  def _viewer_setup(self):
    super()._viewer_setup()
    self.viewer.cam.lookat[Ellipsis] = np.array([1.2, 0.8, 0.5])
    self.viewer.cam.distance = 0.8
    self.viewer.cam.azimuth = 180
    self.viewer.cam.elevation = -30
169

170

171
class FetchPushImage(push.FetchPushEnv):
  """Wrapper for the FetchPush environment with image observations.

  Observations are the current rendered 64x64 frame concatenated with a goal
  frame, both flattened. Reward is 1.0 when the block's xy position is within
  0.05 of the goal xy, else 0.0. Per-step block-to-goal distances are
  recorded in `_dist` and archived per episode in `_dist_vec`.

  Args:
    camera: Viewer preset, 'camera1' or 'camera2'. Any other value raises
      NotImplementedError when the viewer is first set up.
    start_at_obj: If True, each episode starts with the gripper moved next to
      the block; otherwise the arm takes 5 random actions instead.
    rand_y: If False, the block's y coordinate is pinned to 0.75 when the
      goal is generated; if True, the y left by the parent reset is kept.
  """

  # Number of attempts reset() makes to find a state with the block still on
  # the table before giving up.
  _MAX_RESET_ATTEMPTS = 100
  # A block whose z coordinate is below this is treated as fallen off the
  # table.
  _MIN_BLOCK_Z = 0.4

  def __init__(self, camera='camera2', start_at_obj=True, rand_y=False):
    self._start_at_obj = start_at_obj
    self._rand_y = rand_y
    self._camera_name = camera
    self._dist = []
    self._dist_vec = []
    super().__init__()
    self._old_observation_space = self.observation_space
    self._new_observation_space = gym.spaces.Box(
        low=np.full((64*64*6), 0),
        high=np.full((64*64*6), 255),
        dtype=np.uint8)
    self.observation_space = self._new_observation_space
    self.sim.model.geom_rgba[1:5] = 0  # Hide the lasers

  def reset_metrics(self):
    """Clears all recorded distances."""
    self._dist_vec = []
    self._dist = []

  def _move_hand_to_obj(self):
    """Steps the arm toward the block, offset by -0.02 along x.

    Stops once within 0.06 of that point or after 100 steps.
    """
    s = super()._get_obs()
    for _ in range(100):
      hand = s['observation'][:3]  # presumably the gripper position
      obj = s['achieved_goal'] + np.array([-0.02, 0.0, 0.0])
      delta = obj - hand
      if np.linalg.norm(delta) < 0.06:
        break
      a = np.concatenate([np.clip(delta, -1, 1), [0.0]])
      s, _, _, _ = super().step(a)

  def _parent_reset(self):
    """Calls the parent reset() with the original observation space active.

    The image observation space is reinstated even if the parent reset
    raises.
    """
    self.observation_space = self._old_observation_space
    try:
      return super().reset()
    finally:
      self.observation_space = self._new_observation_space

  def _block_xyz(self):
    """Returns a copy of the block's current (x, y, z) joint position."""
    return self.sim.data.get_joint_qpos('object0:joint')[:3].copy()

  def _make_goal_image(self):
    """Renders a new goal image and sets `self._goal` to the block's xy.

    Returns:
      False if the block fell off the table (the attempt must be retried),
      True otherwise.
    """
    s = self._parent_reset()
    # Drive the arm with action [-1, 0, 0, 0] for 8 steps before placing the
    # block.
    for _ in range(8):
      super().step(np.array([-1.0, 0.0, 0.0, 0.0]))
    object_qpos = self.sim.data.get_joint_qpos('object0:joint')
    if not self._rand_y:
      object_qpos[1] = 0.75
    self.sim.data.set_joint_qpos('object0:joint', object_qpos)
    self._move_hand_to_obj()
    self._goal_img = self.observation(s)
    block_xyz = self._block_xyz()
    if block_xyz[2] < self._MIN_BLOCK_Z:
      return False
    self._goal = block_xyz[:2].copy()
    return True

  def _try_reset(self):
    """Makes one full reset attempt.

    Returns:
      The stacked current/goal image, or None if the block fell off the
      table during either phase. No distance is recorded for a failed
      attempt.
    """
    if not self._make_goal_image():
      return None

    s = self._parent_reset()
    for _ in range(8):
      super().step(np.array([-1.0, 0.0, 0.0, 0.0]))
    object_qpos = self.sim.data.get_joint_qpos('object0:joint')
    object_qpos[:2] = np.array([1.15, 0.75])
    self.sim.data.set_joint_qpos('object0:joint', object_qpos)
    if self._start_at_obj:
      self._move_hand_to_obj()
    else:
      for _ in range(5):
        super().step(self.action_space.sample())

    block_xyz = self._block_xyz()
    # Validate before recording: the old recursive version appended this
    # distance first, so a discarded attempt leaked into `_dist_vec`.
    if block_xyz[2] < self._MIN_BLOCK_Z:
      return None
    img = self.observation(s)
    self._dist.append(np.linalg.norm(block_xyz[:2] - self._goal))
    return np.concatenate([img, self._goal_img])

  def reset(self):
    """Starts a new episode and returns the stacked current/goal image.

    An attempt first generates a goal image (see `_make_goal_image`). It then
    resets again, places the block at xy (1.15, 0.75) and positions the arm
    for the episode start. If the block ends up below `_MIN_BLOCK_Z` in either
    phase, the whole attempt is retried.

    Returns:
      Flattened current image concatenated with the flattened goal image.

    Raises:
      RuntimeError: if no valid start state is found within
        `_MAX_RESET_ATTEMPTS` attempts.
    """
    if self._dist:  # Archive the previous episode's distances, if any.
      self._dist_vec.append(self._dist)
    self._dist = []

    # Bounded loop instead of self-recursion, so repeated bad resets cannot
    # exhaust the call stack.
    for _ in range(self._MAX_RESET_ATTEMPTS):
      result = self._try_reset()
      if result is not None:
        return result
      print('Bad reset, retrying.')
    raise RuntimeError(
        f'FetchPushImage.reset: block fell off the table in '
        f'{self._MAX_RESET_ATTEMPTS} consecutive attempts.')

  def step(self, action):
    s, _, _, _ = super().step(action)
    block_xy = self.sim.data.get_joint_qpos('object0:joint')[:2]
    dist = np.linalg.norm(block_xy - self._goal)
    self._dist.append(dist)
    done = False
    r = float(dist < 0.05)  # Taken from the original task code.
    info = {}
    img = self.observation(s)
    return np.concatenate([img, self._goal_img]), r, done, info

  def observation(self, observation):
    """Renders the current scene; the `observation` argument is unused."""
    # Site 0 is moved far away, presumably to keep its marker out of frame.
    self.sim.data.site_xpos[0] = 1_000_000
    img = self.render(mode='rgb_array', height=64, width=64)
    return img.flatten()

  def _viewer_setup(self):
    super()._viewer_setup()
    if self._camera_name == 'camera1':
      self.viewer.cam.lookat[Ellipsis] = np.array([1.2, 0.8, 0.4])
      self.viewer.cam.distance = 0.9
      self.viewer.cam.azimuth = 180
      self.viewer.cam.elevation = -40
    elif self._camera_name == 'camera2':
      self.viewer.cam.lookat[Ellipsis] = np.array([1.25, 0.8, 0.4])
      self.viewer.cam.distance = 0.65
      self.viewer.cam.azimuth = 90
      self.viewer.cam.elevation = -40
    else:
      raise NotImplementedError(
          f'Unsupported camera {self._camera_name!r}; '
          f"expected 'camera1' or 'camera2'.")
281

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.