# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NDP architecture."""

import diffrax
import flax.linen as nn
import jax
import jax.numpy as jnp

from hct.common import model_blocks
from hct.common import typing
from hct.common import utils

class NDPEncoder(nn.Module):
  """Encoder module for NDP.

  Implements the following maps:
    image, hf_obs --> (zs, g, W),
  where:
    zs: image-embedding,
    g: NDP-ODE goal vector,
    W: NDP-ODE weights matrix (flattened).
  """

  action_dim: int  # dimension of the action space
  zs_dim: int = 64  # image-embedding dimension
  zs_width: int = 128  # width of the image encoder network
  num_basis_fncs: int = 4  # number of RBF basis functions in the NDP-ODE
  activation: typing.ActivationFunction = nn.relu

  def setup(self):
    # Image encoder: image -> zs.
    self.image_map = model_blocks.ResNetBatchNorm(
        embed_dim=self.zs_dim, width=self.zs_width
    )

    # Goal head: (activation(zs), hf_obs) -> g.
    self.goal_map = model_blocks.MLP(
        [self.action_dim * 2, self.action_dim],
        activate_final=False,
        activation=self.activation,
    )

    # Weights head: (activation(zs), hf_obs) -> flattened W.
    num_weights = self.num_basis_fncs * self.action_dim
    self.weights_map = model_blocks.MLP(
        [num_weights, num_weights],
        activate_final=False,
        activation=self.activation,
    )

  def __call__(self, images, hf_obs, train=False):
    """Encode a batch of (image, hf state) observation pairs.

    Args:
      images: (batch_size, height, width, channel)-ndarray.
      hf_obs: (batch_size, x_dim)-ndarray, current hf state observations.
      train: boolean for batchnorm.

    Returns:
      zs: (batch_size, zs_dim): image embeddings.
      goals: (batch_size, action_dim): NDP-ODE goals.
      weights: (batch_size, num_basis_fncs * action_dim): NDP-ODE weights.
    """
    assert len(hf_obs.shape) == 2
    assert hf_obs.shape[0] == images.shape[0]

    zs = self.image_map(images, train)
    # Fuse the (activated) image embedding with the hf state observation.
    state_embedding = jnp.concatenate((self.activation(zs), hf_obs), axis=1)

    goals = self.goal_map(state_embedding)
    weights = self.weights_map(state_embedding)

    return zs, goals, weights

class NDPDecoder(nn.Module):
  """NDP Decoder.

  This models the map:
    zs, hf_obs --> u(0), u_dot(0),
  where:
    zs: image-embedding,
    u(0), u_dot(0): initial control & time-derivative for NDP-ODE.
  """

  action_dim: int  # dimension of the action space
  zo_dim: int = 32  # fused latent dimension
  activation: typing.ActivationFunction = nn.relu

  def setup(self):
    # Fusion network: (activation(zs), hf_obs) -> zo latent.
    self.fusion_map = model_blocks.MLP(
        [2 * self.zo_dim] * 3 + [self.zo_dim],
        activate_final=True,
        activation=self.activation,
    )
    # Output head: (zo, hf_obs) -> (u(0), u_dot(0)), stacked.
    out_dim = self.action_dim * 2
    self.out_map = model_blocks.MLP(
        [out_dim * 8, out_dim * 4, out_dim * 2, out_dim],
        activate_final=False,
        activation=self.activation,
    )

  def __call__(self, zs, hf_obs):
    """Decode image embedding & hf state into NDP-ODE initial conditions.

    Args:
      zs: (batch_size, zs_dim): image embedding.
      hf_obs: (batch_size, x_dim): current hf state observation.

    Returns:
      (batch_size, 2*action_dim)-ndarray holding (u(0), u_dot(0)).
    """
    state_embedding = jnp.concatenate((self.activation(zs), hf_obs), axis=1)
    zo = self.fusion_map(state_embedding)
    # Re-append the hf observation before the output head.
    zo_x = jnp.concatenate((zo, hf_obs), axis=1)
    return self.out_map(zo_x)

class NDP(nn.Module):
  """Full NDP architecture."""

  action_dim: int  # dimension of action space
  num_actions: int  # number of actions between two successive observations

  # Loss fnc
  loss_fnc: typing.LossFunction  # loss between true and predicted action

  activation: typing.ActivationFunction = nn.relu

  # Low-freq encoder
  zs_dim: int = 64  # image embedding dimension
  zs_width: int = 128  # width of image encoder network

  # Decoder network
  zo_dim: int = 32  # width of decoder network

  # NDP-ODE hyperparameters
  num_basis_fncs: int = 4
  alpha_p: float = 1.0
  alpha_u: float = 10.0
  beta: float = 2.5

  # Integrator
  ode_solver: diffrax.AbstractSolver = diffrax.Tsit5()
  ode_solver_dt: float = 1e-2
  adjoint: diffrax.AbstractAdjoint = diffrax.RecursiveCheckpointAdjoint()

  def setup(self):
    # Setup low-frequency encoder.
    self.encoder = NDPEncoder(self.action_dim, self.zs_dim, self.zs_width,
                              self.num_basis_fncs, self.activation)

    # Setup decoder network.
    self.decoder = NDPDecoder(self.action_dim, self.zo_dim, self.activation)

    # Setup ODE params: RBF centers are placed along the trajectory of the
    # phase variable p(tau) = exp(-alpha_p * tau) for tau in [0, 1].
    rbf_centers = jnp.exp(
        -self.alpha_p * jnp.linspace(0, 1, self.num_basis_fncs))
    rbf_h = self.num_basis_fncs / rbf_centers

    def ndp_forcer(p, rbf_weights, goal, u0):
      # RBF-weighted forcing term of the NDP-ODE, evaluated at phase p.
      psi = jnp.exp(-rbf_h * (p - rbf_centers)**2)
      weights_mat = jnp.reshape(rbf_weights,
                                (self.action_dim, self.num_basis_fncs))
      return (1. / jnp.sum(psi)) * (weights_mat @ psi) * p * (goal - u0)
    self.ndp_forcer = ndp_forcer

    # Setup integration parameters; one prediction step spans step_delta of
    # the normalized [0, 1] time interval.
    self.step_delta = 1 / self.num_actions
    assert self.ode_solver_dt <= self.step_delta

    self.sample_times = jnp.arange(self.num_actions) * self.step_delta

    def interp_control(tau, u_samples):
      # Linearly interpolate each control dimension at time tau.
      return jax.vmap(jnp.interp, in_axes=(None, None, 1))(
          tau, self.sample_times, u_samples)
    self.interp_control = interp_control

  def encode(self, images, hf_obs, train=False):
    """Encode observations (image & hf state) into latent variables.

    Args:
      images: (batch_size, height, width, channel)-ndarray.
      hf_obs: (batch_size, x_dim)-ndarray, current hf state observation.
      train: boolean for batchnorm

    Returns:
      zs: (batch_size, zs_dim): image embedding
      goals: (batch_size, action_dim): NDP goals.
      weights: (batch_size, num_weights): NDP weights.
    """
    return self.encoder(images, hf_obs, train)

  def decode(self, zs, hf_obs):
    """Decode image embedding & hf state into initial conditions for NDP-ODE.

    Args:
      zs: (batch_size, zs_dim): image embedding
      hf_obs: (batch_size, x_dim): current hf state observation.

    Returns:
      ndp_init: (batch_size, 2*action_dim): initial conditions for NDP ODE.
    """
    return self.decoder(zs, hf_obs)

  def __call__(self, images, hf_obs):
    """Main call function - computes NDP Flow."""
    return self.compute_ndp_flow(images, hf_obs, self.sample_times)

  def compute_ndp_flow(self, images, hf_obs, pred_times):
    """Compute the NDP solution at the desired prediction times.

    Args:
      images: (batch_size, ....): images
      hf_obs: (batch_size, x_dim): concurrent hf observations
      pred_times: (num_times,): prediction times, starting at 0.

    Returns:
      u_pred: (batch_size, num_times, u_dim): predicted sequence of actions.
    """
    assert jnp.max(pred_times) <= 1.

    batch_size = images.shape[0]
    zs, goals, rbf_weights = self.encoder(images, hf_obs, train=False)
    init_u_udot = self.decoder(zs, hf_obs)
    u0s = init_u_udot[:, :self.action_dim]
    # Phase variable p starts at 1 and decays during the flow.
    init_ps = jnp.ones((batch_size, 1))
    init_ndp_states = jnp.hstack((init_u_udot, init_ps))

    term = diffrax.ODETerm(self._ndp_ode)
    saveat = diffrax.SaveAt(ts=pred_times)

    def flow_one(init_state, weights, goal, u0):
      # Integrate one (unbatched) NDP-ODE and return the control trajectory.
      sol = diffrax.diffeqsolve(
          term, self.ode_solver, 0., pred_times[-1],
          self.ode_solver_dt, y0=init_state,
          args=(weights, goal, u0),
          adjoint=self.adjoint,
          saveat=saveat)
      return sol.ys[:, :self.action_dim]

    return jax.vmap(flow_one)(init_ndp_states, rbf_weights, goals, u0s)

  def compute_augmented_flow(self, images, hf_obs, u_true, train=False):
    """Compute augmented flow for entire prediction period.

    Args:
      images: (batch_size, ....): images
      hf_obs: (batch_size, x_dim): concurrent hf observations
      u_true: (batch_size, num_actions, u_dim): observed control actions
      train: boolean for batchnorm.

    Returns:
      u_pred: (batch_size, num_actions, u_dim): predicted sequence of actions.
      net_loss: (batch_size,) integral of loss over the prediction period.
    """
    batch_size = images.shape[0]
    assert u_true.shape[1] == self.num_actions

    zs, goals, rbf_weights = self.encoder(images, hf_obs, train)
    init_u_udot = self.decoder(zs, hf_obs)
    u0s = init_u_udot[:, :self.action_dim]
    init_ps = jnp.ones((batch_size, 1))
    init_ndp_states = jnp.hstack((init_u_udot, init_ps))

    term = diffrax.ODETerm(self._aug_ode)
    saveat = diffrax.SaveAt(ts=self.sample_times)

    def flow_one(init_ndp_state, u_samples, weights, goal, u0):
      # Augmented state = (ndp_state, running cost), cost starts at 0.
      aug_state_0 = (init_ndp_state, 0.0)
      args = ((weights, goal, u0), u_samples)
      sol = diffrax.diffeqsolve(
          term, self.ode_solver, 0., self.sample_times[-1],
          self.ode_solver_dt, y0=aug_state_0,
          args=args,
          adjoint=self.adjoint,
          saveat=saveat)
      # Return the control trajectory and the final accumulated cost.
      return sol.ys[0][:, :self.action_dim], sol.ys[1][-1]

    return jax.vmap(flow_one)(init_ndp_states, u_true, rbf_weights, goals, u0s)

  @nn.nowrap
  def _ndp_ode(self, tau, ndp_state, args):
    # Define the (unbatched) NDP ode over the state (u, u_dot, p).
    del tau
    rbf_weights, goal, u0 = args
    u, u_dot, p = jnp.split(ndp_state, [self.action_dim, 2*self.action_dim])
    force = self.ndp_forcer(p, rbf_weights, goal, u0)
    # Second-order goal-attractor dynamics plus the learned forcing term.
    u_ddot = self.alpha_u*(self.beta * (goal - u) - u_dot) + force
    p_dot = -self.alpha_p * p
    return jnp.concatenate((u_dot, u_ddot, p_dot))

  def _step_ndp(self, ndp_state, tau, ndp_args):
    """Flow NDP forward by one prediction step; unbatched."""
    term = diffrax.ODETerm(self._ndp_ode)

    sol = diffrax.diffeqsolve(
        term, self.ode_solver, tau, tau + self.step_delta,
        self.ode_solver_dt, y0=ndp_state,
        args=ndp_args)
    return sol.ys[0], tau + self.step_delta

  @nn.nowrap
  def _aug_ode(self, tau, aug_state, args):
    # Define the (unbatched) NDP+cost ode over the state (ndp_state, cost).
    ndp_args, u_samples = args
    ndp_state, _ = aug_state
    ndp_out = self._ndp_ode(tau, ndp_state, ndp_args)
    u = ndp_state[:self.action_dim]
    # linearly interpolate true control samples
    u_true = self.interp_control(tau, u_samples)
    cost = self.loss_fnc(u_true, u)
    return ndp_out, cost

  @property
  def step_functions(self):
    """Return the 'low-frequency' and 'high-frequency' step functions."""

    def re_init(
        model_params, image, hf_obs
    ):
      """Re-initialize the NDP-ODE.

      Args:
        model_params: all params for the model.
        image: current image.
        hf_obs: concurrent hf observation.

      Returns:
        ndp_state: initialized NDP state (u(0), u_dot(0), phi(0)).
        ndp_args: ndp args (weights, goal, u0) for the NDP ODE.
      """
      zs, goal, weights = utils.unbatch_flax_fn(self.apply)(
          model_params, image, hf_obs, train=False, method=self.encode
      )
      ndp_state = utils.unbatch_flax_fn(self.apply)(
          model_params, zs, hf_obs, method=self.decode
      )
      u0 = ndp_state[: self.action_dim]
      # Append the initial phase value p(0) = 1.
      return jnp.append(ndp_state, 1.0), (weights, goal, u0)

    def step_fwd(
        model_params,
        ndp_state,
        tau,
        ndp_args,
    ):
      """Flow NDP forward by one prediction step.

      Args:
        model_params: all params for the model.
        ndp_state: current NDP ODE state; (u(t), u_dot(t), phi(t)).
        tau: interpolation time, in between [0, 1 - (1/num_actions)].
        ndp_args: NDP-ODE args returned by re_init.

      Returns:
        ndp_state: ndp_state at (tau + step_delta).
        tau': tau + step_delta.
      """
      return self.apply(model_params, ndp_state, tau, ndp_args,
                        method=self._step_ndp)

    return re_init, step_fwd