google-research
122 lines · 3.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Simple grid-world environment.

The task here is to walk to the (max_x, max_y) position in a square grid.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from typing import Any, Dict, Tuple, Union
28
29
class GridWalk(object):
  """Walk on grid to target location.

  The agent occupies an (x, y) cell of a length x length square grid and the
  task is to reach the (length - 1, length - 1) corner. Each step moves the
  agent one cell in one of four directions (moves off the grid edge are
  silently ignored), and the reward decays exponentially with the taxicab
  distance to the target. Episodes never terminate (`done` is always False).
  """

  def __init__(self, length: int, tabular_obs: bool = True):
    """Initializes the environment.

    Args:
      length: The length of the square gridworld.
      tabular_obs: Whether to use tabular observations. Otherwise observations
        are x, y coordinates.
    """
    self._length = length
    self._tabular_obs = tabular_obs
    self._n_state = length ** 2
    self._n_action = 4
    # The target is the (max_x, max_y) corner of the grid.
    self._target_x = length - 1
    self._target_y = length - 1
    # Place the agent at a uniformly random starting cell. Delegating to
    # reset() avoids duplicating the random-placement logic.
    self.reset()

  def reset(self) -> Union[int, np.ndarray]:
    """Resets the agent to a uniformly random square.

    Returns:
      The initial observation (flat index if tabular, else an (x, y) array).
    """
    self._x = np.random.randint(self._length)
    self._y = np.random.randint(self._length)
    return self._get_obs()

  def _get_obs(self) -> Union[int, np.ndarray]:
    """Gets the current observation.

    Returns:
      The flat index x * length + y if tabular observations are enabled,
      otherwise a numpy array [x, y].
    """
    if self._tabular_obs:
      return self._x * self._length + self._y
    else:
      return np.array([self._x, self._y])

  def get_tabular_obs(self, xy_obs: np.ndarray) -> np.ndarray:
    """Gets tabular observation given non-tabular (x,y) observation.

    Args:
      xy_obs: Array whose last dimension holds (x, y) coordinates; may be
        batched with arbitrary leading dimensions.

    Returns:
      Flat indices length * x + y, with the trailing coordinate dimension
      reduced away.
    """
    return self._length * xy_obs[Ellipsis, 0] + xy_obs[Ellipsis, 1]

  def get_xy_obs(self, state: Union[int, np.ndarray]) -> np.ndarray:
    """Gets (x,y) coordinates given tabular observation.

    Args:
      state: Scalar or array of flat state indices.

    Returns:
      An array with a trailing dimension of size 2 holding (x, y).
    """
    x = state // self._length
    y = state % self._length
    return np.stack([x, y], axis=-1)

  def step(self, action: int) -> Tuple[
      Union[int, np.ndarray], float, bool, Dict[str, Any]]:
    """Perform a step in the environment.

    Args:
      action: A valid action (one of 0, 1, 2, 3), moving the agent in the
        +x, +y, -x, -y direction respectively. Moves that would leave the
        grid are ignored (the agent stays in place).

    Returns:
      next_obs: Observation after action is applied.
      reward: Environment step reward, exp(-2 * taxicab distance / length);
        equals 1.0 exactly at the target cell.
      done: Whether the episode has terminated; always False here.
      info: A dictionary of additional environment information (empty).

    Raises:
      ValueError: If the input action is invalid.
    """
    if action == 0:
      if self._x < self._length - 1:
        self._x += 1
    elif action == 1:
      if self._y < self._length - 1:
        self._y += 1
    elif action == 2:
      if self._x > 0:
        self._x -= 1
    elif action == 3:
      if self._y > 0:
        self._y -= 1
    else:
      raise ValueError('Invalid action %s.' % action)
    taxi_distance = (np.abs(self._x - self._target_x) +
                     np.abs(self._y - self._target_y))
    reward = np.exp(-2 * taxi_distance / self._length)
    done = False
    return self._get_obs(), reward, done, {}

  @property
  def num_states(self) -> int:
    """Total number of tabular states (length ** 2)."""
    return self._n_state  # pytype: disable=bad-return-type  # bind-properties

  @property
  def num_actions(self) -> int:
    """Number of discrete actions (4)."""
    return self._n_action

  @property
  def state_dim(self) -> int:
    """Observation dimension: 1 if tabular, else 2 for (x, y) coordinates."""
    return 1 if self._tabular_obs else 2

  @property
  def action_dim(self) -> int:
    """Number of discrete actions (alias of num_actions)."""
    return self._n_action