google-research
131 lines · 3.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Demonstrates usage of jax_particles."""
17import time
18import jax
19import jax.numpy as jnp
20from jax_particles import BaseEnvironment
21from jax_particles import Entity
22from jax_particles.renderer import ThreadedRenderer
23
24
class Environment(BaseEnvironment):
  """Toy environment with one agent, one collideable object and one landmark.

  Defined inline to keep the demo self-contained; a real environment would
  usually live in its own module.
  """

  def __init__(self):
    super().__init__()

    def new_entity(name, color, collideable):
      # Creates an Entity and sets the attributes every entity here shares.
      entity = Entity()
      entity.name = name
      entity.color = color
      entity.collideable = collideable
      return entity

    # The agent and the object both take part in collisions.
    agent = new_entity("agent", [64, 64, 64], True)
    obstacle = new_entity("object", [64, 64, 204], True)
    # The landmark is a non-collideable marker with an explicit radius.
    landmark = new_entity("landmark", [64, 205, 64], False)
    landmark.radius = 0.1

    # Register in the order agent, object, landmark; init_state sizes its
    # arrays by the number of registered entities.
    for group in ([agent], [obstacle], [landmark]):
      self.entities.extend(group)

    # A single agent, hence actions of shape (1, 2); no observations here.
    self.a_shape = (1, 2)
    self.o_shape = {"vec": None}

    self._compile()

  def init_state(self, rng):
    """Samples an initial state for a single environment.

    Positions are drawn uniformly between `self.min_p` and `self.max_p`;
    velocities start at zero.

    Args:
      rng: JAX PRNG key used to sample the positions.

    Returns:
      A (p, v, misc) tuple: p and v have shape (num_entities, 2), misc is
      None.
    """
    n_entities = len(self.entities)
    span = self.max_p - self.min_p
    positions = jax.random.uniform(rng, (n_entities, 2)) * span + self.min_p
    velocities = jnp.zeros((n_entities, 2))
    return positions, velocities, None

  def obs(self, s):
    """Returns the observation dict; this demo observes nothing ("vec" is None)."""
    del s  # Unused: there is nothing to observe.
    return {"vec": None}

  def reward(self, s):
    """Returns the joint reward, which is always the Python float 0.0 here."""
    del s  # Unused: the reward does not depend on the state.
    return 0.0


def main(batch_size=5, i_batch=1, stop_every_step=False, seed=1):
  """Creates a batch of environments and steps them with user actions.

  `batch_size` copies of `Environment` are simulated in parallel (vmapped and
  jitted), and only copy `i_batch` is rendered. Each iteration reads one
  action from the renderer and applies it to every copy. The loop never exits
  on its own; it runs until the process is interrupted.

  Args:
    batch_size: Number of environments simulated in parallel; must be >= 1.
    i_batch: Index of the environment to render. Python-style negative
      indices are accepted, i.e. -batch_size <= i_batch < batch_size.
    stop_every_step: If True, wait for Enter after every step. Otherwise pace
      the loop so that each step takes at least `env.dt` seconds.
    seed: Seed of the PRNG key used to sample the initial states.

  Raises:
    ValueError: If `batch_size` < 1 or `i_batch` is out of range.
  """
  # Validate before building anything, so bad arguments fail fast instead of
  # surfacing later as an indexing error inside the render loop.
  if batch_size < 1:
    raise ValueError(f"batch_size must be >= 1, got {batch_size}")
  if not -batch_size <= i_batch < batch_size:
    raise ValueError(
        f"i_batch must be in [{-batch_size}, {batch_size}), got {i_batch}")

  env = Environment()
  renderer = ThreadedRenderer()

  def _step(s, a):
    # One unbatched transition: advance the state, then observe and score it.
    s = env.step(s, a)
    return (s, env.obs(s), env.reward(s))

  # Vectorize over the leading batch axis, then compile.
  step = jax.jit(jax.vmap(_step))
  init_state = jax.jit(jax.vmap(env.init_state))

  s = init_state(jax.random.split(jax.random.PRNGKey(seed), batch_size))
  # Use a monotonic clock: time.time() can jump backwards when the wall clock
  # is adjusted, which would stall the pacing below for the size of the jump.
  last_time = time.monotonic()
  while True:
    # Slice out the state of the rendered environment (None entries stay None).
    s_i = [elem[i_batch, :] if elem is not None else None for elem in s]
    renderer.render(env, s_i)

    # Apply the same user action to every environment in the batch.
    a = renderer.get_action()
    a = jnp.broadcast_to(jnp.array(a), (batch_size,) + env.a_shape)

    s, _, _ = step(s, a)

    if stop_every_step:
      input("> ")
    else:
      # Sleep once for the remainder of the frame instead of busy-polling.
      remaining = env.dt - (time.monotonic() - last_time)
      if remaining > 0:
        time.sleep(remaining)
      last_time = time.monotonic()


if __name__ == "__main__":
  # Launch the interactive demo only when run as a script, not on import.
  main()