# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Environments for experiments with RCE.
"""

import inspect
import os

from absl import logging
import d4rl  # pylint: disable=unused-import
import gin
import gym
from metaworld.envs.mujoco import sawyer_xyz
import numpy as np
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
import tqdm

# We need to import d4rl so that gym registers the environments.
os.environ['SDL_VIDEODRIVER'] = 'dummy'


def _get_image_obs(self):
  # The observation returned here should be in [0, 255].
  obs = self.get_image(width=84, height=84)
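  # Reverse the row order, i.e. flip the image vertically (the MuJoCo
  # offscreen renderer returns images upside down).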
  return obs[::-1]


@gin.configurable
def load_env(env_name, max_episode_steps=None):
  """Loads an environment.

  Args:
    env_name: Name of the environment.
    max_episode_steps: Maximum number of steps per episode. The value passed
      here is overridden for every environment handled below.

  Returns:
    tf_env: A TFPyEnvironment.
  """
  if env_name == 'sawyer_reach':
    gym_env = SawyerReach()
    max_episode_steps = 51
  elif env_name == 'sawyer_push':
    gym_env = SawyerPush()
    max_episode_steps = 151
  elif env_name == 'sawyer_lift':
    gym_env = SawyerLift()
    max_episode_steps = 151
  elif env_name == 'sawyer_drawer_open':
    gym_env = SawyerDrawerOpen()
    max_episode_steps = 151
  elif env_name == 'sawyer_drawer_close':
    gym_env = SawyerDrawerClose()
    max_episode_steps = 151
  elif env_name == 'sawyer_box_close':
    gym_env = SawyerBoxClose()
    max_episode_steps = 151
  elif env_name == 'sawyer_bin_picking':
    gym_env = SawyerBinPicking()
    max_episode_steps = 151
  else:
    gym_spec = gym.spec(env_name)
    gym_env = gym_spec.make()
    max_episode_steps = gym_spec.max_episode_steps

  env = suite_gym.wrap_env(
      gym_env,
      max_episode_steps=max_episode_steps)
  tf_env = tf_py_environment.TFPyEnvironment(env)
  return tf_env
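

# Illustrative usage (a sketch, not part of the original module): load one of
# the environments above and drive it with a random tf_agents policy. Assumes
# `from tf_agents.policies import random_tf_policy`.
#
#   tf_env = load_env('sawyer_push')
#   policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
#                                            tf_env.action_spec())
#   time_step = tf_env.reset()
#   time_step = tf_env.step(policy.action(time_step).action)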


@gin.configurable(denylist=['env', 'env_name'])
def get_data(env, env_name, num_expert_obs=200, terminal_offset=50):
  """Loads the success examples.

  Args:
    env: A PyEnvironment for which we want to generate success examples.
    env_name: The name of the environment.
    num_expert_obs: The number of success examples to generate.
    terminal_offset: For the d4rl datasets, success examples are randomly
      subsampled from the last N steps before each terminal state. The
      terminal_offset parameter is N.

  Returns:
    expert_obs: Array with the success examples.
  """
  if env_name in ['hammer-human-v0', 'door-human-v0', 'relocate-human-v0']:
    dataset = env.get_dataset()
    terminals = np.where(dataset['terminals'])[0]
    expert_obs = np.concatenate(
        [dataset['observations'][t - terminal_offset:t] for t in terminals],
        axis=0)
    indices = np.random.choice(
        len(expert_obs), size=num_expert_obs, replace=False)
    expert_obs = expert_obs[indices]
  else:
    # For environments where we generate the expert dataset on the fly, we can
    # improve performance by only generating the number of expert observations
    # that we'll actually use. Not all environments support this, so we first
    # have to check whether the environment's get_dataset method accepts a
    # num_obs kwarg.
    get_dataset_args = inspect.getfullargspec(env.get_dataset).args
    if 'num_obs' in get_dataset_args:
      dataset = env.get_dataset(num_obs=num_expert_obs)
    else:
      dataset = env.get_dataset()
    indices = np.random.choice(
        dataset['observations'].shape[0], size=num_expert_obs, replace=False)
    expert_obs = dataset['observations'][indices]
  if 'image' in env_name:
    expert_obs = expert_obs.astype(np.uint8)
  logging.info('Done loading expert observations')
  return expert_obs
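

# Illustrative usage (a sketch, not part of the original module): get_data only
# relies on env.get_dataset(), so it should be given an environment object that
# exposes that method, e.g. a d4rl gym env or one of the raw Sawyer classes
# defined below.
#
#   gym_env = SawyerDrawerOpen()
#   expert_obs = get_data(gym_env, env_name='sawyer_drawer_open',
#                         num_expert_obs=200)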


class SawyerReach(sawyer_xyz.SawyerReachPushPickPlaceEnv):
  """A simple reaching environment."""

  def __init__(self):
    super(SawyerReach, self).__init__(task_type='reach')
    self.initialize_camera(self.init_camera)

  def step(self, action):
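    # Replace the native metaworld reward with a simple shaped reward: the
    # one-step decrease in the distance between the gripper (obs[:3]) and the
    # puck (obs[3:]). The environment never terminates the episode itself; the
    # time limit comes from the wrapper created in load_env.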
    obs = self._get_obs()
    d_before = np.linalg.norm(obs[:3] - obs[3:])
    s, r, done, info = super(SawyerReach, self).step(action)
    d_after = np.linalg.norm(s[:3] - s[3:])
    r = d_before - d_after
    done = False
    return s, r, done, info

  @gin.configurable(module='SawyerReach')
  def reset(self, random=False, width=1.0, random_color=False,
            random_size=False):
    if random_color:
      geom_id = self.model.geom_name2id('objGeom')
      rgb = np.random.uniform(np.zeros(3), np.ones(3))
      rgba = np.concatenate([rgb, [1.0]])
      self.model.geom_rgba[geom_id, :] = rgba
    if random_size:
      geom_id = self.model.geom_name2id('objGeom')
      low = np.array([0.01, 0.005, 0.0])
      high = np.array([0.05, 0.045, 0.0])
      size = np.random.uniform(low, high)
      self.model.geom_size[geom_id, :] = size
    super(SawyerReach, self).reset()

    if random:
      low = np.array([-0.2, 0.4, 0.02])
      high = np.array([0.2, 0.8, 0.02])
      if width == 1:
        scaled_low = low
        scaled_high = high
      else:
        mean = (low + high) / 2.0
        scaled_low = mean - width * (mean - low)
        scaled_high = mean + width * (high - mean)
      puck_pos = np.random.uniform(low=scaled_low, high=scaled_high)
      self._set_obj_xyz_quat(puck_pos, 0.0)

    # Hide the default goals and other markers. We use the puck position as
    # the goal. This must happen after self._set_obj_xyz_quat(...).
    self._state_goal = 10 * np.ones(3)
    self._set_goal_marker(self._state_goal)
    return self._get_obs()

  def _get_expert_obs(self):
    self.reset()
    # Don't use the observation returned from self.reset because this will be
    # an image for SawyerReachImage.
    obs = self._get_obs()
    self.data.set_mocap_pos('mocap', obs[3:6])
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    # Hide the markers, which get reset after every simulation step.
    self._set_goal_marker(self._state_goal)
    return self._get_obs()

  @gin.configurable(module='SawyerReach')
  def init_camera(self, camera, mode='default'):
    if mode == 'human':
      camera.distance = 0.5
      camera.lookat[0] = 0.6
      camera.lookat[1] = 1.0
      camera.lookat[2] = 0.5
      camera.elevation = -20
      camera.azimuth = 230
      camera.trackbodyid = -1
    elif mode == 'default':
      camera.lookat[0] = 0
      camera.lookat[1] = 0.85
      camera.lookat[2] = 0.3
      camera.distance = 0.4
      camera.elevation = -35
      camera.azimuth = 270
      camera.trackbodyid = -1
    elif mode == 'v2':
      camera.lookat[0] = 0
      camera.lookat[1] = 0.6
      camera.lookat[2] = 0.0
      camera.distance = 0.7
      camera.elevation = -35
      camera.azimuth = 180
      camera.trackbodyid = -1
    else:
      raise NotImplementedError

  def get_dataset(self, num_obs=256):
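    # Note: get_data above only consumes the 'observations' entry for these
    # environments; the randomly sampled actions and all-zero rewards simply
    # fill out the dataset dict.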
    # This generates examples at ~145 observations / sec. When using image
    # observations it slows down to ~17 FPS.
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset


class SawyerPush(sawyer_xyz.SawyerReachPushPickPlaceEnv):
  """A pushing environment."""

  def __init__(self):
    super(SawyerPush, self).__init__(task_type='push')
    self.initialize_camera(self.init_camera)
    self._goal = np.array([0.1, 0.85, 0.02])

  def step(self, action):
    obs = self._get_obs()
    d_before = np.linalg.norm(obs[3:] - self._goal)
    s, _, done, info = super(SawyerPush, self).step(action)
    d_after = np.linalg.norm(s[3:] - self._goal)
    r = d_before - d_after
    done = False
    return s, r, done, info

  @gin.configurable(module='SawyerPush')
  def _get_expert_obs(self, hand_at_puck=True, wide=False, off_table=False):
    self.reset()
    if wide:
      puck_pos = np.random.uniform(low=[-0.15, 0.8, 0.02],
                                   high=[0.15, 0.9, 0.02])
    else:
      puck_pos = np.random.uniform(low=[0.05, 0.8, 0.02],
                                   high=[0.15, 0.9, 0.02])
    if off_table:
      assert not wide
      assert not hand_at_puck
      puck_pos = 10 * np.ones(3,)
    self._set_obj_xyz_quat(puck_pos, 0.0)
    if hand_at_puck:
      hand_goal = puck_pos
    else:
      hand_goal = np.random.uniform(low=[-0.2, 0.4, 0.02],
                                    high=[0.2, 0.8, 0.3])
    self.data.set_mocap_pos('mocap', hand_goal)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  @gin.configurable(module='SawyerPush')
  def init_camera(self, camera, mode='default'):
    if mode == 'human':
      camera.distance = 0.5
      camera.lookat[0] = 0.6
      camera.lookat[1] = 1.0
      camera.lookat[2] = 0.5
      camera.elevation = -20
      camera.azimuth = 230
      camera.trackbodyid = -1
    elif mode == 'default':
      camera.lookat[0] = 0
      camera.lookat[1] = 0.9
      camera.lookat[2] = 0.3
      camera.distance = 0.4
      camera.elevation = -45
      camera.azimuth = 270
      camera.trackbodyid = -1
    elif mode == 'front':
      camera.lookat[0] = 0
      camera.lookat[1] = 0.85
      camera.lookat[2] = 0.05
      camera.distance = 0.4
      camera.elevation = 0
      camera.azimuth = 270
      camera.trackbodyid = -1
    elif mode == 'side':
      camera.lookat[0] = 0
      camera.lookat[1] = 0.7
      camera.lookat[2] = 0.05
      camera.distance = 0.6
      camera.elevation = 0
      camera.azimuth = 180
      camera.trackbodyid = -1
    else:
      raise NotImplementedError

  def get_dataset(self, num_obs=256):
    # This generates examples at ~145 observations / sec.
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset


class SawyerLift(sawyer_xyz.SawyerReachPushPickPlaceEnv):
  """A task of lifting up an object."""

  MODE = 'train'

  def __init__(self):
    super(SawyerLift, self).__init__(task_type='reach')
    self.initialize_camera(self.init_camera)

  def _get_dist(self, z):
    min_height, max_height = self.target_height()
    d_above = abs(z - max_height)
    d_below = abs(z - min_height)
    if min_height <= z <= max_height:
      return 0.0
    else:
      return min(d_above, d_below)

  @gin.configurable(module='SawyerLift')
  def target_height(self, target_height=0.1):
    """Values 0.1 through 0.3 are reasonable."""
    if isinstance(target_height, tuple) or isinstance(target_height, list):
      min_height, max_height = target_height
    else:
      min_height = target_height - 0.02
      max_height = target_height + 0.02
    return (min_height, max_height)

  def step(self, action):
    obs = self._get_obs()
    d_before = self._get_dist(obs[5])
    # d_before = abs(obs[5] - self.target_height())
    s, r, done, info = super(SawyerLift, self).step(action)
    d_after = self._get_dist(s[5])
    # d_after = abs(s[5] - self.target_height())

    r = d_before - d_after
    done = False
    return s, r, done, info

  def init_camera(self, camera):
    camera.distance = 0.5
    camera.lookat[0] = 0.6
    camera.lookat[1] = 1.0
    camera.lookat[2] = 0.5
    camera.elevation = -20
    camera.azimuth = 230
    camera.trackbodyid = -1

  @gin.configurable(module='SawyerLift')
  def reset(self, reset_to_goal=False):
    super(SawyerLift, self).reset()
    if reset_to_goal and self.MODE == 'train':
      self._get_expert_obs(reset=False)
    return self._get_obs()

  @gin.configurable(module='SawyerLift')
  def _get_expert_obs(self, reset=True):
    if reset:
      self.reset()
    obs = self._get_obs()
    puck_pos = obs[3:6]
    min_height, max_height = self.target_height()
    puck_pos[-1] = np.random.uniform(min_height, max_height)
    self.data.set_mocap_pos('mocap', puck_pos)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    # We have to set the puck position after moving the arm. Otherwise
    # the puck will fall while setting the arm position.
    self._set_obj_xyz_quat(puck_pos, 0.0)
    return self._get_obs()

  def get_dataset(self, num_obs=256):
    # This generates examples at ~145 observations / sec.
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset


class SawyerDrawerOpen(sawyer_xyz.SawyerDrawerOpenEnv):
  """A drawer opening task."""

  def __init__(self):
    super(SawyerDrawerOpen, self).__init__()
    self.initialize_camera(self.init_camera)

  def step(self, action):
    obs = self._get_obs()
    d_before = np.linalg.norm(obs[4] - self.goal[1])
    s, r, done, info = super(SawyerDrawerOpen, self).step(action)
    d_after = np.linalg.norm(s[4] - self.goal[1])
    r = d_before - d_after
    done = False
    return s, r, done, info

  def init_camera(self, camera):
    camera.distance = 1.
    camera.lookat[0] = 0.0
    camera.lookat[1] = 0.4
    camera.lookat[2] = 0.3
    camera.elevation = -20
    camera.azimuth = 160
    camera.trackbodyid = -1

  @gin.configurable(module='SawyerDrawerOpen')
  def _get_expert_obs(self, hand_at_goal=True):
    self.reset()
    pos = np.random.uniform(-0.25, -0.15)
    self._set_obj_xyz(pos)
    if hand_at_goal:
      hand_goal = self._get_obs()[3:]
    else:
      hand_goal = np.random.uniform(low=[-0.2, 0.4, 0.02],
                                    high=[0.2, 0.8, 0.3])

    self.data.set_mocap_pos('mocap', hand_goal)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  def get_dataset(self, num_obs=256):
    # This generates examples at ~145 observations / sec.
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset


@gin.configurable
class SawyerDrawerClose(sawyer_xyz.SawyerDrawerCloseEnv):
  """A drawer closing task."""

  def __init__(self, random_init=False):
    super(SawyerDrawerClose, self).__init__(random_init=random_init)
    self.initialize_camera(self.init_camera)

  def step(self, action):
    obs = self._get_obs()
    d_before = np.linalg.norm(obs[4] - self.goal[1])
    s, r, done, info = super(SawyerDrawerClose, self).step(action)
    d_after = np.linalg.norm(s[4] - self.goal[1])
    r = d_before - d_after
    done = False
    return s, r, done, info

  def init_camera(self, camera):
    camera.distance = 1.
    camera.lookat[0] = 0.0
    camera.lookat[1] = 0.4
    camera.lookat[2] = 0.3
    camera.elevation = -20
    camera.azimuth = 160
    camera.trackbodyid = -1

  @gin.configurable(module='SawyerDrawerClose')
  def _get_expert_obs(self, hand_at_goal=True):
    self.reset()
    pos = np.random.uniform(0.0, 0.05)
    self._set_obj_xyz(pos)
    if hand_at_goal:
      hand_goal = self._get_obs()[3:]
    else:
      hand_goal = np.random.uniform(low=[-0.2, 0.4, 0.02],
                                    high=[0.2, 0.8, 0.3])

    self.data.set_mocap_pos('mocap', hand_goal)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  def get_dataset(self, num_obs=256):
    # This generates examples at ~145 observations / sec.
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset


class SawyerBoxClose(sawyer_xyz.SawyerBoxCloseEnv):
  """The task is to put a lid on a box.

  The observation dimension is 9: 3 for the hand, 3 for the lid, 3 for the
  goal.
  """

  def __init__(self):
    super(SawyerBoxClose, self).__init__()
    self.initialize_camera(self.init_camera)

  def _get_goal_pos(self, obs):
    goal_pos = obs[-3:]
    goal_pos[-1] -= 0.085
    return goal_pos

  def step(self, action):
    obs = self._get_obs()
    goal_pos = self._get_goal_pos(obs)
    d_before = np.linalg.norm(obs[3:6] - goal_pos)
    s, _, _, info = super(SawyerBoxClose, self).step(action)
    d_after = np.linalg.norm(s[3:6] - goal_pos)
    r = d_before - d_after
    done = False
    return s, r, done, info

  def _get_expert_obs(self):
    self.reset()
    obs = self._get_obs()
    goal_pos = obs[-3:]
    goal_pos[-1] -= 0.085
    self._set_obj_xyz_quat(goal_pos, self.obj_init_angle)

    self.data.set_mocap_pos('mocap', obs[-3:])
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  def init_camera(self, camera):
    camera.distance = 1.
    camera.lookat[0] = 0.0
    camera.lookat[1] = 1.0
    camera.lookat[2] = 0.1
    camera.elevation = -10
    camera.azimuth = 270
    camera.trackbodyid = -1

  def get_dataset(self, num_obs=256):
    # This generates examples at ~273 observations / sec.
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset


class SawyerBinPicking(sawyer_xyz.SawyerBinPickingEnv):
  """A pick and place task."""

  def __init__(self):
    super(SawyerBinPicking, self).__init__()
    self.initialize_camera(self.init_camera)

  def step(self, action):
    obs = self._get_obs()
    goal_pos = np.array([0.12, 0.7, 0.046])
    d_before = np.linalg.norm(obs[3:6] - goal_pos)
    s, _, _, info = super(SawyerBinPicking, self).step(action)
    d_after = np.linalg.norm(s[3:6] - goal_pos)
    r = d_before - d_after
    done = False
    return s, r, done, info

  def _get_expert_obs(self):
    self.reset()
    goal_pos = np.random.uniform(
        low=np.array([0.06, 0.64, 0.046]), high=np.array([0.18, 0.76, 0.046]))
    self._set_obj_xyz_quat(goal_pos, self.obj_init_angle)
    self.data.set_mocap_pos('mocap', goal_pos)
    self.data.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
    for _ in range(10):
      self.do_simulation([-1, 1], self.frame_skip)
    return self._get_obs()

  @gin.configurable(module='SawyerBinPicking')
  def init_camera(self, camera, mode='default'):
    if mode == 'default':
      camera.distance = 0.5
      camera.lookat[0] = 0.6
      camera.lookat[1] = 1.0
      camera.lookat[2] = 0.5
      camera.elevation = -20
      camera.azimuth = 230
      camera.trackbodyid = -1
    elif mode == 'side':
      camera.lookat[0] = 0.2
      camera.lookat[1] = 0.7
      camera.lookat[2] = 0.2
      camera.distance = 0.3
      camera.elevation = -30
      camera.azimuth = 180
      camera.trackbodyid = -1
    elif mode == 'front':
      camera.lookat[0] = 0.0
      camera.lookat[1] = 0.9
      camera.lookat[2] = 0.2
      camera.distance = 0.3
      camera.elevation = -30
      camera.azimuth = 270
      camera.trackbodyid = -1
    else:
      raise NotImplementedError

  def get_dataset(self, num_obs=256):
    # This generates examples at ~95 observations / sec.
    action_vec = [self.action_space.sample() for _ in range(num_obs)]
    obs_vec = [self._get_expert_obs() for _ in tqdm.trange(num_obs)]
    dataset = {
        'observations': np.array(obs_vec, dtype=np.float32),
        'actions': np.array(action_vec, dtype=np.float32),
        'rewards': np.zeros(num_obs, dtype=np.float32),
    }
    return dataset
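

# A minimal smoke test, added purely for illustration; it is not part of the
# original module. It assumes the MuJoCo, metaworld, and tf_agents
# dependencies imported above are installed and working.
if __name__ == '__main__':
  import tensorflow as tf  # Only needed for this illustrative check.

  demo_env = load_env('sawyer_reach')
  demo_env.reset()
  zero_action = tf.zeros((1,) + tuple(demo_env.action_spec().shape))
  next_step = demo_env.step(zero_action)
  logging.info('Stepped sawyer_reach; observation shape: %s',
               next_step.observation.shape)

  expert_obs = get_data(SawyerReach(), env_name='sawyer_reach',
                        num_expert_obs=4)
  logging.info('Generated %d success examples.', len(expert_obs))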
