google-research

Форк
0
122 строки · 3.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Simple grid-world environment.
17

18
The task here is to walk to the (max_x, max_y) position in a square grid.
19
"""
20

21
from __future__ import absolute_import
22
from __future__ import division
23
from __future__ import print_function
24

25
import numpy as np
26

27
from typing import Any, Dict, Tuple, Union
28

29

30
class GridWalk(object):
31
  """Walk on grid to target location."""
32

33
  def __init__(self, length, tabular_obs = True):
34
    """Initializes the environment.
35

36
    Args:
37
      length: The length of the square gridworld.
38
      tabular_obs: Whether to use tabular observations. Otherwise observations
39
        are x, y coordinates.
40
    """
41
    self._length = length
42
    self._tabular_obs = tabular_obs
43
    self._x = np.random.randint(length)
44
    self._y = np.random.randint(length)
45
    self._n_state = length ** 2
46
    self._n_action = 4
47
    self._target_x = length - 1
48
    self._target_y = length - 1
49

50
  def reset(self):
51
    """Resets the agent to a random square."""
52
    self._x = np.random.randint(self._length)
53
    self._y = np.random.randint(self._length)
54
    return self._get_obs()
55

56
  def _get_obs(self):
57
    """Gets current observation."""
58
    if self._tabular_obs:
59
      return self._x * self._length + self._y
60
    else:
61
      return np.array([self._x, self._y])
62

63
  def get_tabular_obs(self, xy_obs):
64
    """Gets tabular observation given non-tabular (x,y) observation."""
65
    return self._length * xy_obs[Ellipsis, 0] + xy_obs[Ellipsis, 1]
66

67
  def get_xy_obs(self, state):
68
    """Gets (x,y) coordinates given tabular observation."""
69
    x = state // self._length
70
    y = state % self._length
71
    return np.stack([x, y], axis=-1)
72

73
  def step(self, action):
74
    """Perform a step in the environment.
75

76
    Args:
77
      action: A valid action (one of 0, 1, 2, 3).
78

79
    Returns:
80
      next_obs: Observation after action is applied.
81
      reward: Environment step reward.
82
      done: Whether the episode has terminated.
83
      info: A dictionary of additional environment information.
84

85
    Raises:
86
      ValueError: If the input action is invalid.
87
    """
88
    if action == 0:
89
      if self._x < self._length - 1:
90
        self._x += 1
91
    elif action == 1:
92
      if self._y < self._length - 1:
93
        self._y += 1
94
    elif action == 2:
95
      if self._x > 0:
96
        self._x -= 1
97
    elif action == 3:
98
      if self._y > 0:
99
        self._y -= 1
100
    else:
101
      raise ValueError('Invalid action %s.' % action)
102
    taxi_distance = (np.abs(self._x - self._target_x) +
103
                     np.abs(self._y - self._target_y))
104
    reward = np.exp(-2 * taxi_distance / self._length)
105
    done = False
106
    return self._get_obs(), reward, done, {}
107

108
  @property
109
  def num_states(self):
110
    return self._n_state  # pytype: disable=bad-return-type  # bind-properties
111

112
  @property
113
  def num_actions(self):
114
    return self._n_action
115

116
  @property
117
  def state_dim(self):
118
    return 1 if self._tabular_obs else 2
119

120
  @property
121
  def action_dim(self):
122
    return self._n_action
123

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.