google-research

Форк
0
/
ndp_model.py 
352 строки · 11.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""NDP architecture."""
17

18
import diffrax
19
import flax.linen as nn
20
import jax
21
import jax.numpy as jnp
22

23

24
from hct.common import model_blocks
25
from hct.common import typing
26
from hct.common import utils
27

28

29
class NDPEncoder(nn.Module):
  """Observation encoder that produces the NDP-ODE parameters.

  Maps (images, hf_obs) to the triple (zs, goals, weights), where:
    zs: image embedding returned by the image network,
    goals: NDP-ODE goal vector (last MLP layer has `action_dim` units),
    weights: flattened NDP-ODE weights matrix (last MLP layer has
      `num_basis_fncs * action_dim` units).
  """
  action_dim: int
  zs_dim: int = 64
  zs_width: int = 128
  num_basis_fncs: int = 4
  activation: typing.ActivationFunction = nn.relu

  def setup(self):
    # Image --> embedding network.
    self.image_map = model_blocks.ResNetBatchNorm(
        embed_dim=self.zs_dim, width=self.zs_width)

    # Head predicting the goal vector.
    goal_layers = [2 * self.action_dim, self.action_dim]
    self.goal_map = model_blocks.MLP(
        goal_layers,
        activate_final=False,
        activation=self.activation)

    # Head predicting the flattened weights matrix.
    flat_weights_dim = self.num_basis_fncs * self.action_dim
    self.weights_map = model_blocks.MLP(
        [flat_weights_dim] * 2,
        activate_final=False,
        activation=self.activation)

  def __call__(self, images, hf_obs, train=False):
    """Encode a batch of observations.

    Args:
      images: batch of images; leading axis is the batch axis.
      hf_obs: rank-2 array of hf observations with the same batch size.
      train: forwarded to the image network (batchnorm mode).

    Returns:
      zs: image embedding (before the activation is applied).
      goals: output of the goal head.
      weights: output of the weights head.
    """
    assert hf_obs.ndim == 2
    assert images.shape[0] == hf_obs.shape[0]

    zs = self.image_map(images, train)
    # Both heads consume the activated embedding joined with the hf state.
    features = jnp.concatenate([self.activation(zs), hf_obs], axis=1)

    return zs, self.goal_map(features), self.weights_map(features)
75

76

77
class NDPDecoder(nn.Module):
  """Decoder producing the initial conditions of the NDP-ODE.

  Maps (zs, hf_obs) to a single vector holding u(0) and u_dot(0), where:
    zs: image embedding,
    u(0), u_dot(0): initial control and its time-derivative; the output
      layer therefore has `2 * action_dim` units.
  """
  action_dim: int
  zo_dim: int = 32
  activation: typing.ActivationFunction = nn.relu

  def setup(self):
    # Fuses the image embedding with the hf observation.
    fusion_layers = [2 * self.zo_dim] * 3 + [self.zo_dim]
    self.fusion_map = model_blocks.MLP(
        fusion_layers,
        activate_final=True,
        activation=self.activation)

    # Output head whose widths shrink 8x -> 4x -> 2x -> 1x of out_dim.
    out_dim = 2 * self.action_dim
    out_layers = [out_dim * scale for scale in (8, 4, 2, 1)]
    self.out_map = model_blocks.MLP(
        out_layers,
        activate_final=False,
        activation=self.activation)

  def __call__(self, zs, hf_obs):
    """Return the output-head result for a batch of (zs, hf_obs)."""
    fusion_in = jnp.concatenate([self.activation(zs), hf_obs], axis=1)
    zo = self.fusion_map(fusion_in)
    # The hf observation is fed in a second time, next to the fused code.
    return self.out_map(jnp.concatenate([zo, hf_obs], axis=1))
105

106

107
class NDP(nn.Module):
  """Complete NDP model: encoder, decoder and the NDP-ODE integrator."""

  action_dim: int  # dimension of action space
  num_actions: int  # number of actions between two successive observations

  # Loss function between true and predicted action.
  loss_fnc: typing.LossFunction

  activation: typing.ActivationFunction = nn.relu

  # Low-frequency encoder.
  zs_dim: int = 64  # image embedding dimension
  zs_width: int = 128  # width of image encoder network

  # Decoder network.
  zo_dim: int = 32  # width of decoder network

  # NDP-ODE hyperparameters.
  num_basis_fncs: int = 4
  alpha_p: float = 1.0
  alpha_u: float = 10.0
  beta: float = 2.5

  # Integrator settings.
  ode_solver: diffrax.AbstractSolver = diffrax.Tsit5()
  ode_solver_dt: float = 1e-2
  adjoint: diffrax.AbstractAdjoint = diffrax.RecursiveCheckpointAdjoint()

  def setup(self):
    # Low-frequency encoder.
    self.encoder = NDPEncoder(
        action_dim=self.action_dim,
        zs_dim=self.zs_dim,
        zs_width=self.zs_width,
        num_basis_fncs=self.num_basis_fncs,
        activation=self.activation)

    # Decoder network.
    self.decoder = NDPDecoder(
        action_dim=self.action_dim,
        zo_dim=self.zo_dim,
        activation=self.activation)

    # Radial-basis centers and widths used by the forcing term.
    rbf_centers = jnp.exp(
        -self.alpha_p * jnp.linspace(0, 1, self.num_basis_fncs))
    rbf_h = self.num_basis_fncs / rbf_centers

    def ndp_forcer(p, rbf_weights, goal, u0):
      # Normalized basis mixture, scaled by p and by (goal - u0).
      psi = jnp.exp(-rbf_h * (p - rbf_centers)**2)
      weights_mat = jnp.reshape(
          rbf_weights, (self.action_dim, self.num_basis_fncs))
      normalizer = 1. / jnp.sum(psi)
      mixture = weights_mat @ psi
      return normalizer * mixture * p * (goal - u0)
    self.ndp_forcer = ndp_forcer

    # Integration grid: num_actions samples spaced step_delta apart from 0.
    self.step_delta = 1 / self.num_actions
    assert self.ode_solver_dt <= self.step_delta

    self.sample_times = jnp.arange(self.num_actions) * self.step_delta

    def interp_control(tau, u_samples):
      # Interpolate every column of u_samples over sample_times at tau.
      per_column_interp = jax.vmap(jnp.interp, in_axes=(None, None, 1))
      return per_column_interp(tau, self.sample_times, u_samples)
    self.interp_control = interp_control

  def encode(self, images, hf_obs, train=False):
    """Run the encoder on a batch of observations.

    Args:
      images: (batch_size, height, width, channel)-ndarray.
      hf_obs: (batch_size, x_dim)-ndarray, current hf state observation.
      train: boolean for batchnorm

    Returns:
      zs: (batch_size, zs_dim): image embedding
      goals: (batch_size, action_dim): NDP goals.
      weights: (batch_size, num_weights): NDP weights.
    """
    return self.encoder(images, hf_obs, train)

  def decode(self, zs, hf_obs):
    """Run the decoder to obtain initial conditions for the NDP-ODE.

    Args:
      zs: (batch_size, zs_dim): image embedding
      hf_obs: (batch_size, x_dim): current hf state observation.

    Returns:
      ndp_init: (batch_size, 2*action_dim): initial conditions for NDP ODE.
    """
    return self.decoder(zs, hf_obs)

  def __call__(self, images, hf_obs):
    """Compute the NDP flow at the module's own sample times."""
    return self.compute_ndp_flow(images, hf_obs, self.sample_times)

  def compute_ndp_flow(self, images, hf_obs, pred_times):
    """Solve the NDP-ODE and report the controls at `pred_times`.

    The encoder is always run with train=False here.

    Args:
      images: (batch_size, ....): images
      hf_obs: (batch_size, x_dim): concurrent hf observations
      pred_times: (num_times,): prediction times, starting at 0.

    Returns:
      u_pred: (batch_size, num_times, u_dim): predicted sequence of actions.
    """
    assert jnp.max(pred_times) <= 1.

    zs, goals, rbf_weights = self.encoder(images, hf_obs, train=False)
    init_u_udot = self.decoder(zs, hf_obs)
    u0s = init_u_udot[:, :self.action_dim]

    # Full ODE state is (u, u_dot, p), with p starting at 1.
    p0s = jnp.ones((images.shape[0], 1))
    init_ndp_states = jnp.concatenate([init_u_udot, p0s], axis=1)

    ode_term = diffrax.ODETerm(self._ndp_ode)
    save_spec = diffrax.SaveAt(ts=pred_times)

    def flow_one(init_state, weights, goal, u0):
      # Unbatched solve; only the u-component of the solution is returned.
      solution = diffrax.diffeqsolve(
          ode_term,
          self.ode_solver,
          t0=0.,
          t1=pred_times[-1],
          dt0=self.ode_solver_dt,
          y0=init_state,
          args=(weights, goal, u0),
          adjoint=self.adjoint,
          saveat=save_spec)
      return solution.ys[:, :self.action_dim]

    batched_flow = jax.vmap(flow_one)
    return batched_flow(init_ndp_states, rbf_weights, goals, u0s)

  def compute_augmented_flow(self, images, hf_obs, u_true, train=False):
    """Solve the NDP-ODE jointly with a running loss integral.

    Args:
      images: (batch_size, ....): images
      hf_obs: (batch_size, x_dim): concurrent hf observations
      u_true: (batch_size, num_actions, u_dim): observed control actions
      train: boolean for batchnorm.

    Returns:
      u_pred: (batch_size, num_actions, u_dim): predicted sequence of actions.
      net_loss: (batch_size,) integral of loss over the prediction period.
    """
    assert u_true.shape[1] == self.num_actions

    zs, goals, rbf_weights = self.encoder(images, hf_obs, train)
    init_u_udot = self.decoder(zs, hf_obs)
    u0s = init_u_udot[:, :self.action_dim]

    # Full ODE state is (u, u_dot, p), with p starting at 1.
    p0s = jnp.ones((images.shape[0], 1))
    init_ndp_states = jnp.concatenate([init_u_udot, p0s], axis=1)

    ode_term = diffrax.ODETerm(self._aug_ode)
    save_spec = diffrax.SaveAt(ts=self.sample_times)

    def flow_one(init_ndp_state, u_samples, weights, goal, u0):
      # The augmented state pairs the NDP state with the accumulated cost.
      solution = diffrax.diffeqsolve(
          ode_term,
          self.ode_solver,
          t0=0.,
          t1=self.sample_times[-1],
          dt0=self.ode_solver_dt,
          y0=(init_ndp_state, 0.0),
          args=((weights, goal, u0), u_samples),
          adjoint=self.adjoint,
          saveat=save_spec)
      ndp_traj, cost_traj = solution.ys
      # Controls at every saved time; cost only at the final saved time.
      return ndp_traj[:, :self.action_dim], cost_traj[-1]

    batched_flow = jax.vmap(flow_one)
    return batched_flow(init_ndp_states, u_true, rbf_weights, goals, u0s)

  @nn.nowrap
  def _ndp_ode(self, tau, ndp_state, args):
    # Unbatched NDP vector field over the state (u, u_dot, p); it does not
    # depend on time explicitly.
    del tau
    rbf_weights, goal, u0 = args

    vel_start = self.action_dim
    p_start = 2 * self.action_dim
    u = ndp_state[:vel_start]
    u_dot = ndp_state[vel_start:p_start]
    p = ndp_state[p_start:]

    force = self.ndp_forcer(p, rbf_weights, goal, u0)
    pull_to_goal = self.beta * (goal - u)
    u_ddot = self.alpha_u * (pull_to_goal - u_dot) + force
    p_dot = -self.alpha_p * p
    return jnp.concatenate([u_dot, u_ddot, p_dot])

  def _step_ndp(self, ndp_state, tau, ndp_args):
    """Integrate the unbatched NDP-ODE from tau to tau + step_delta."""
    tau_next = tau + self.step_delta
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(self._ndp_ode),
        self.ode_solver,
        t0=tau,
        t1=tau_next,
        dt0=self.ode_solver_dt,
        y0=ndp_state,
        args=ndp_args)
    # First (index 0) saved entry of the solution.
    return solution.ys[0], tau + self.step_delta

  @nn.nowrap
  def _aug_ode(self, tau, aug_state, args):
    # Unbatched vector field over the pair (ndp_state, cost).
    ndp_args, u_samples = args
    ndp_state = aug_state[0]

    ndp_out = self._ndp_ode(tau, ndp_state, ndp_args)

    # Cost rate: loss between the interpolated true control and current u.
    u_pred = ndp_state[:self.action_dim]
    u_ref = self.interp_control(tau, u_samples)
    return ndp_out, self.loss_fnc(u_ref, u_pred)

  @property
  def step_functions(self):
    """Return the 'low-frequency' and 'high-frequency' step functions."""

    def re_init(model_params, image, hf_obs):
      """Re-initialize the NDP-ODE.

      Args:
        model_params: all params for the model.
        image: current image.
        hf_obs: concurrent hf observation.

      Returns:
        ndp_state: initialized NDP state (u(0), u_dot(0), p(0)) with p(0)=1.
        ndp_args: ndp args (weights, goal, u0) for the NDP ODE.
      """
      encode_fn = utils.unbatch_flax_fn(self.apply)
      zs, goal, weights = encode_fn(
          model_params, image, hf_obs, train=False, method=self.encode)

      decode_fn = utils.unbatch_flax_fn(self.apply)
      u_udot_0 = decode_fn(model_params, zs, hf_obs, method=self.decode)

      u0 = u_udot_0[:self.action_dim]
      return jnp.append(u_udot_0, 1.0), (weights, goal, u0)

    def step_fwd(model_params, ndp_state, tau, ndp_args):
      """Flow NDP forward by one prediction step.

      Args:
        model_params: all params for the model.
        ndp_state: current NDP ODE state; (u(t), u_dot(t), p(t)).
        tau: interpolation time, in between [0, 1 - (1/num_actions)].
        ndp_args: NDP-ODE args returned by re_init.

      Returns:
        ndp_state: ndp_state at (tau + step_delta).
        tau': tau + step_delta.
      """
      return self.apply(
          model_params, ndp_state, tau, ndp_args, method=self._step_ndp)

    return re_init, step_fwd
353

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.