google-research

Форк
0
131 строка · 3.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Demonstrates usage of jax_particles."""
17
import time
18
import jax
19
import jax.numpy as jnp
20
from jax_particles import BaseEnvironment
21
from jax_particles import Entity
22
from jax_particles.renderer import ThreadedRenderer
23

24

25
class Environment(BaseEnvironment):
26
  """A simple demo environment. Typically you would create a module for this."""
27

28
  def __init__(self):
29
    super().__init__()
30

31
    # add agents
32
    agents = []
33
    agent = Entity()
34
    agent.name = "agent"
35
    agent.color = [64, 64, 64]
36
    agent.collideable = True
37
    agents.append(agent)
38

39
    # add collideable objects
40
    objects = []
41
    obj = Entity()
42
    obj.name = "object"
43
    obj.color = [64, 64, 204]
44
    obj.collideable = True
45
    objects.append(obj)
46

47
    # add landmarks
48
    landmarks = []
49
    landmark = Entity()
50
    landmark.name = "landmark"
51
    landmark.color = [64, 205, 64]
52
    landmark.collideable = False
53
    landmark.radius = 0.1
54
    landmarks.append(landmark)
55

56
    self.entities.extend(agents)
57
    self.entities.extend(objects)
58
    self.entities.extend(landmarks)
59

60
    self.a_shape = (1, 2)  # shape of actions, one agent here
61
    self.o_shape = {"vec": None}  # shape of observations, no observations here
62

63
    self._compile()
64

65
  def init_state(self, rng):
66
    """Returns a state (p,v,misc) tuple."""
67
    shape = (len(self.entities), 2)
68
    p = jax.random.uniform(rng, shape) * (self.max_p - self.min_p) + self.min_p
69
    v = jnp.zeros(shape)+0.0
70
    return (p, v, None)
71

72
  def obs(self, s):
73
    """Returns an observation dictionary."""
74
    o = {"vec": None}
75
    return o
76

77
  def reward(self, s):
78
    """Returns a joint reward: np.float32."""
79
    return 0.0
80

81

82
def main():
83
  """Create an environment and step through, taking user actions."""
84
  batch_size = 5  # 100000
85
  i_batch = 1  # 25843
86
  stop_every_step = False
87

88
  env = Environment()
89
  renderer = ThreadedRenderer()
90

91
  # compile a step function
92
  def step(s, a):
93
    s = env.step(s, a)
94
    o = env.obs(s)
95
    r = env.reward(s)
96
    return (s, o, r)
97
  step = jax.vmap(step)
98
  step = jax.jit(step)
99

100
  # compile an init_state function
101
  init_state = env.init_state
102
  init_state = jax.vmap(init_state)
103
  init_state = jax.jit(init_state)
104

105
  rng = jax.random.PRNGKey(1)
106
  s = init_state(jax.random.split(rng, batch_size))
107
  last_time = time.time()
108
  while True:
109
    # extract the state of the i'th environment
110
    s_i = [elem[i_batch, :] if elem is not None else None for elem in s]
111

112
    # render
113
    renderer.render(env, s_i)
114

115
    # get user action
116
    a = renderer.get_action()
117
    a = jnp.broadcast_to(jnp.array(a), (batch_size,)+env.a_shape)
118

119
    # do simulation step
120
    s, o, r = step(s, a)  # pylint: disable=unused-variable
121

122
    if stop_every_step:
123
      input("> ")
124
    else:
125
      while time.time() - last_time < env.dt:
126
        time.sleep(0.001)
127
      last_time = time.time()
128

129

130
if __name__ == "__main__":
131
  main()
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.