# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definition of models."""

import abc
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import chex
import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

@jax.jit
def inverse_leaky_relu(y):
  """Inverse of the default jax.nn.leaky_relu."""
  alpha = jnp.where(y > 0, 1, 0.01)
  return y / alpha


@jax.jit
def inverse_softplus(y):
  """Inverse of jax.nn.softplus, adapted from TensorFlow Probability."""
  threshold = jnp.log(jnp.finfo(jnp.float32).eps) + 2.
  is_too_small = y < jnp.exp(threshold)
  is_too_large = y > -threshold
  too_small_value = jnp.log(y)
  too_large_value = y
  y = jnp.where(is_too_small | is_too_large, 1., y)
  x = y + jnp.log(-jnp.expm1(-y))
  return jnp.where(is_too_small, too_small_value,
                   jnp.where(is_too_large, too_large_value, x))


_CUSTOM_ACTIVATIONS = {
    'leaky_relu': jax.nn.leaky_relu,
    'softplus': jax.nn.softplus,
}


_INVERSE_CUSTOM_ACTIVATIONS = {
    'leaky_relu': inverse_leaky_relu,
    'softplus': inverse_softplus,
}


def create_activation_bijector(activation):
  """Creates a bijector for the given activation function."""
  if activation in _CUSTOM_ACTIVATIONS:
    return distrax.Lambda(
        forward=_CUSTOM_ACTIVATIONS[activation],
        inverse=_INVERSE_CUSTOM_ACTIVATIONS[activation])

  activation_fn = getattr(jax.nn, activation, None)
  return distrax.as_bijector(activation_fn)
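

# Illustrative sketch (not part of the original module): the bijector returned
# above pairs each custom activation with its hand-written inverse, so forward
# and inverse should round-trip up to numerical precision. The helper name
# below is hypothetical and only for documentation.
def _example_activation_bijector_roundtrip():
  bijector = create_activation_bijector('softplus')
  x = jnp.linspace(-3.0, 3.0, 7)
  y = bijector.forward(x)
  x_recovered = bijector.inverse(y)
  chex.assert_trees_all_close(x, x_recovered, atol=1e-4)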


def cartesian_to_polar(x,
                       y):
  """Converts cartesian (x, y) coordinates to polar (r, theta) coordinates."""
  r = jnp.sqrt(x**2 + y**2)
  theta = jnp.arctan2(y, x)
  return r, theta


def polar_to_cartesian(r,
                       theta):
  """Converts polar (r, theta) coordinates to cartesian (x, y) coordinates."""
  x = r * jnp.cos(theta)
  y = r * jnp.sin(theta)
  return x, y


class MLP(nn.Module):
  """A multi-layer perceptron (MLP)."""

  latent_sizes: Sequence[int]
  activation: Optional[Callable[[chex.Array], chex.Array]]
  skip_connections: bool = True
  activate_final: bool = False

  @nn.compact
  def __call__(self, inputs):
    for index, dim in enumerate(self.latent_sizes):
      next_inputs = nn.Dense(dim)(inputs)

      if index != len(self.latent_sizes) - 1 or self.activate_final:
        if self.activation is not None:
          next_inputs = self.activation(next_inputs)

      if self.skip_connections and next_inputs.shape == inputs.shape:
        next_inputs = next_inputs + inputs

      inputs = next_inputs
    return inputs
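

# Illustrative usage sketch (not part of the original module), assuming the
# standard Flax init/apply pattern. The function name and sizes are
# hypothetical placeholders.
def _example_mlp_usage():
  mlp = MLP(latent_sizes=[32, 32, 2], activation=jax.nn.softplus)
  inputs = jnp.ones((4, 2))
  params = mlp.init(jax.random.PRNGKey(0), inputs)
  outputs = mlp.apply(params, inputs)
  chex.assert_shape(outputs, (4, 2))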


class NormalizingFlow(abc.ABC, nn.Module):
  """Base class for normalizing flows."""

  @abc.abstractmethod
  def forward(self, inputs):
    """Computes the forward map."""

  @abc.abstractmethod
  def inverse(self, inputs):
    """Computes the inverse map."""

  def __call__(self, inputs):
    return self.forward(inputs)


class MaskedCouplingFlowConditioner(nn.Module):
  """Conditioner for the masked coupling normalizing flow."""

  event_shape: Sequence[int]
  latent_sizes: Sequence[int]
  activation: Callable[[chex.Array], chex.Array]
  num_bijector_params: int

  @nn.compact
  def __call__(self, inputs):
    inputs = jnp.reshape(inputs, (inputs.shape[0], -1))
    inputs = MLP(
        self.latent_sizes, self.activation, activate_final=True)(
            inputs)
    inputs = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params)(
        inputs)
    inputs = jnp.reshape(
        inputs, inputs.shape[:-1] + tuple(self.event_shape) +
        (self.num_bijector_params,))
    return inputs


class MaskedCouplingNormalizingFlow(NormalizingFlow):
  """Implements a masked coupling normalizing flow."""

  event_shape: Sequence[int]
  bijector_fn: Callable[[optax.Params], distrax.Bijector]
  conditioners: Sequence[MaskedCouplingFlowConditioner]

  def setup(self):
    # Alternating binary mask.
    mask = jnp.arange(0, np.prod(self.event_shape)) % 2
    mask = jnp.reshape(mask, self.event_shape)
    mask = mask.astype(bool)

    layers = []
    for conditioner in self.conditioners:
      layer = distrax.MaskedCoupling(
          mask=mask, bijector=self.bijector_fn, conditioner=conditioner)
      layers.append(layer)

      # Flip the mask after each layer.
      mask = jnp.logical_not(mask)

    # Chain layers to create the flow.
    self.flow = distrax.Chain(layers)

  def forward(self, inputs):
    """Encodes inputs as latent vectors."""
    return self.flow.forward(inputs)

  def inverse(self, inputs):
    """Applies the inverse flow to the latents."""
    return self.flow.inverse(inputs)
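

# Illustrative construction sketch (not part of the original module): one
# possible choice of `bijector_fn` is an elementwise affine bijector whose
# shift and log-scale are read from the conditioner output, so each
# conditioner must emit `num_bijector_params=2` values per event dimension.
# The helper names and sizes below are hypothetical.
def _example_affine_bijector_fn(params):
  return distrax.ScalarAffine(
      shift=params[Ellipsis, 0], log_scale=params[Ellipsis, 1])


def _example_masked_coupling_flow():
  event_shape = (4,)
  conditioners = [
      MaskedCouplingFlowConditioner(
          event_shape=event_shape,
          latent_sizes=[16, 16],
          activation=jax.nn.softplus,
          num_bijector_params=2) for _ in range(3)
  ]
  flow = MaskedCouplingNormalizingFlow(
      event_shape=event_shape,
      bijector_fn=_example_affine_bijector_fn,
      conditioners=conditioners)
  inputs = jnp.ones((8, 4))
  params = flow.init(jax.random.PRNGKey(0), inputs)
  latents = flow.apply(params, inputs)
  recovered = flow.apply(params, latents, method=flow.inverse)
  chex.assert_trees_all_close(inputs, recovered, atol=1e-4)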


class OneDimensionalNormalizingFlow(NormalizingFlow):
  """Implements a one-dimensional normalizing flow."""

  num_layers: int
  activation: distrax.Bijector

  def setup(self):
    layers = []
    for index in range(self.num_layers):
      scale = self.param(f'scale_{index}', nn.initializers.lecun_normal(),
                         (1, 1))
      shift = self.param(f'shift_{index}', nn.initializers.lecun_normal(),
                         (1, 1))
      layer = distrax.ScalarAffine(scale=scale, shift=shift)
      layer = distrax.Chain([self.activation, layer])
      layers.append(layer)

    self.flow = distrax.Chain(layers)

  def inverse(self, inputs):
    """Applies the inverse flow to the latents."""
    return self.flow.inverse(inputs)

  def forward(self, inputs):
    """Encodes inputs as latent vectors."""
    return self.flow.forward(inputs)


class PointwiseNormalizingFlow(NormalizingFlow):
  """Implements a pointwise symplectic normalizing flow."""

  base_flow: NormalizingFlow
  base_flow_input_dims: int
  switch: bool = False

  def forward(self, inputs):
    num_dims = self.base_flow_input_dims
    assert inputs.shape[
        -1] == 2 * num_dims, f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.'

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    def dot_product_for_forward(coords_transformed):
      coords = self.base_flow.inverse(coords_transformed)
      return jnp.dot(coords.squeeze(axis=-1), second_coords.squeeze(axis=-1))  # pytype: disable=bad-return-type  # jnp-type

    first_coords_transformed = self.base_flow.forward(first_coords)
    second_coords_transformed = jax.grad(dot_product_for_forward)(
        first_coords_transformed)

    return jnp.concatenate(
        (first_coords_transformed, second_coords_transformed), axis=-1)

  def inverse(self, inputs):
    num_dims = self.base_flow_input_dims
    assert inputs.shape[
        -1] == 2 * num_dims, f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.'

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    def dot_product_for_inverse(coords_inverted):
      coords = self.base_flow.forward(coords_inverted)
      return jnp.dot(coords.squeeze(axis=-1), second_coords.squeeze(axis=-1))  # pytype: disable=bad-return-type  # jnp-type

    first_coords_inverted = self.base_flow.inverse(first_coords)
    second_coords_inverted = jax.grad(dot_product_for_inverse)(
        first_coords_inverted)

    return jnp.concatenate((first_coords_inverted, second_coords_inverted),
                           axis=-1)
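

# Note added for exposition (not in the original source): writing q, p for the
# two coordinate blocks and f for `base_flow`, the forward map above computes
#   Q = f(q),
#   P = d/dQ [ f^{-1}(Q) . p ] = (df/dq)^{-T} p,
# i.e. the momenta are transformed by the inverse-transpose Jacobian of the
# position map. This is the standard point transformation of Hamiltonian
# mechanics, which is what makes the resulting map symplectic.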


class LinearBasedConditioner(nn.Module):
  """Linear module from SympNets."""

  @nn.compact
  def __call__(self, inputs):
    num_dims = inputs.shape[-1]
    w = self.param('w', nn.initializers.normal(0.01), (num_dims, num_dims))
    return inputs @ (w + w.T)


class ActivationBasedConditioner(nn.Module):
  """Activation module from SympNets."""

  activation: Callable[[chex.Array], chex.Array]

  @nn.compact
  def __call__(self, inputs):
    num_dims = inputs.shape[-1]
    a = self.param('a', nn.initializers.normal(0.01), (num_dims,))
    return self.activation(inputs) * a


class GradientBasedConditioner(nn.Module):
  """Gradient module from SympNets."""

  activation: Callable[[chex.Array], chex.Array]
  projection_dims: int
  skip_connections: bool = True

  @nn.compact
  def __call__(self, inputs):
    num_dims = inputs.shape[-1]
    w = self.param('w', nn.initializers.normal(0.01),
                   (num_dims, self.projection_dims))
    b = self.param('b', nn.initializers.normal(0.01), (self.projection_dims,))
    a = self.param('a', nn.initializers.zeros, (self.projection_dims,))
    gate = self.param('gate', nn.initializers.zeros, (num_dims,))

    outputs = self.activation(inputs @ w + b)
    outputs = outputs * a
    outputs = outputs @ w.T
    if self.skip_connections:
      outputs += inputs
    outputs *= gate
    return outputs


class ShearNormalizingFlow(NormalizingFlow):
  """Implements a shearing normalizing flow."""

  conditioner: nn.Module
  conditioner_input_dims: int
  switch: bool = False

  def forward(self, inputs):
    num_dims = self.conditioner_input_dims
    assert inputs.shape[-1] == 2 * num_dims, (
        f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.')

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    first_coords_transformed = first_coords + self.conditioner(second_coords)

    if self.switch:
      first_coords_transformed, second_coords = second_coords, first_coords_transformed

    return jnp.concatenate((first_coords_transformed, second_coords), axis=-1)

  def inverse(self, inputs):
    num_dims = self.conditioner_input_dims
    assert inputs.shape[-1] == 2 * num_dims, (
        f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.')

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    first_coords_inverted = first_coords - self.conditioner(second_coords)

    if self.switch:
      first_coords_inverted, second_coords = second_coords, first_coords_inverted

    return jnp.concatenate((first_coords_inverted, second_coords), axis=-1)
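

# Illustrative sketch (not part of the original module): a single shear layer
# with a SympNet gradient conditioner, checked for invertibility. The shear
# only ever adds the conditioner output to one coordinate block, so the
# inverse is exact regardless of the conditioner. The helper name and sizes
# are hypothetical.
def _example_shear_flow_roundtrip():
  flow = ShearNormalizingFlow(
      conditioner=GradientBasedConditioner(
          activation=jax.nn.softplus, projection_dims=8),
      conditioner_input_dims=2)
  inputs = jnp.ones((5, 4))
  params = flow.init(jax.random.PRNGKey(0), inputs)
  latents = flow.apply(params, inputs)
  recovered = flow.apply(params, latents, method=flow.inverse)
  chex.assert_trees_all_close(inputs, recovered, atol=1e-5)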


class SymplecticLinearFlow(NormalizingFlow):
  """Implements a SymplecticLinear layer from 'Learning Symmetries of Classical Integrable Systems', consisting of a shift, scale and rotation."""

  operation_input_dims: int

  def setup(self):
    num_dims = self.operation_input_dims
    self.shift_val = self.param('shift_val', nn.initializers.zeros, (num_dims,))
    self.scale_val = self.param('scale_val', nn.initializers.ones, (num_dims,))
    self.rotate_val = self.param('rotate_val', nn.initializers.zeros,
                                 (num_dims,))

  def extract_coords(self, inputs):
    """Returns the two sets of coordinates."""
    num_dims = self.operation_input_dims
    assert inputs.shape[
        -1] == 2 * num_dims, f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.'

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]
    return first_coords, second_coords

  def forward(self, inputs):
    """Runs the forward pass."""
    first_coords, second_coords = self.extract_coords(inputs)
    first_coords, second_coords = self.shift(
        first_coords, second_coords, forward=True)
    first_coords, second_coords = self.scale(
        first_coords, second_coords, forward=True)
    first_coords, second_coords = self.rotate(
        first_coords, second_coords, forward=True)
    return jnp.concatenate((first_coords, second_coords), axis=-1)

  def inverse(self, inputs):
    """Computes the inverse of self.forward()."""
    first_coords, second_coords = self.extract_coords(inputs)
    first_coords, second_coords = self.rotate(
        first_coords, second_coords, forward=False)
    first_coords, second_coords = self.scale(
        first_coords, second_coords, forward=False)
    first_coords, second_coords = self.shift(
        first_coords, second_coords, forward=False)
    return jnp.concatenate((first_coords, second_coords), axis=-1)

  def shift(self, first_coords, second_coords,
            forward):
    """Performs the shift transformation on one of the coordinates."""
    shift = self.shift_val
    if not forward:
      shift = -shift
    first_coords_shifted = first_coords + second_coords * shift
    second_coords_shifted = second_coords
    return first_coords_shifted, second_coords_shifted

  def scale(self, first_coords, second_coords,
            forward):
    """Scales the coordinates in a symplectic manner."""
    scale = self.scale_val
    if not forward:
      scale = 1 / scale
    first_coords_scaled = first_coords * scale
    second_coords_scaled = second_coords / scale
    return first_coords_scaled, second_coords_scaled

  def rotate(self, first_coords, second_coords,
             forward):
    """Rotates all of the coordinates."""
    theta = self.rotate_val
    if not forward:
      theta = -theta
    first_coords_rotated = first_coords * jnp.cos(
        theta) - second_coords * jnp.sin(theta)
    second_coords_rotated = first_coords * jnp.sin(
        theta) + second_coords * jnp.cos(theta)
    return first_coords_rotated, second_coords_rotated
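

# Note added for exposition (not in the original source): for each coordinate
# pair (q_i, p_i) the three steps above act as 2x2 matrices,
#   shift:  [[1, s], [0, 1]],   scale: [[a, 0], [0, 1/a]],
#   rotate: [[cos t, -sin t], [sin t, cos t]],
# each of which has unit determinant, so their composition preserves the
# symplectic form on every (q_i, p_i) plane.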


class SequentialFlow(NormalizingFlow):
  """Adaptation of nn.Sequential() for flows."""

  flows: Sequence[NormalizingFlow]

  def forward(self, inputs):
    for flow in self.flows:
      inputs = flow.forward(inputs)
    return inputs

  def inverse(self, inputs):
    for flow in reversed(self.flows):
      inputs = flow.inverse(inputs)
    return inputs

  def __call__(self, inputs):
    return self.forward(inputs)
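

# Illustrative sketch (not part of the original module): SympNet-style layers
# can be chained into a single invertible map; the inverse applies the layers
# in reverse order. The helper name and layer choices are hypothetical.
def _example_sequential_flow():
  flow = SequentialFlow(flows=[
      SymplecticLinearFlow(operation_input_dims=2),
      ShearNormalizingFlow(
          conditioner=LinearBasedConditioner(), conditioner_input_dims=2),
      ShearNormalizingFlow(
          conditioner=LinearBasedConditioner(),
          conditioner_input_dims=2,
          switch=True),
  ])
  inputs = jnp.ones((5, 4))
  params = flow.init(jax.random.PRNGKey(0), inputs)
  latents = flow.apply(params, inputs)
  recovered = flow.apply(params, latents, method=flow.inverse)
  chex.assert_trees_all_close(inputs, recovered, atol=1e-4)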


class CoordinateEncoder(abc.ABC, nn.Module):
  """Base class for encoders."""

  @abc.abstractmethod
  def __call__(self, positions,
               momentums):
    """Returns corresponding angles and momentums by encoding inputs."""


class CoordinateDecoder(abc.ABC, nn.Module):
  """Base class for decoders."""

  @abc.abstractmethod
  def __call__(self, actions,
               angles):
    """Returns corresponding positions and momentums by decoding inputs."""


class MLPEncoder(CoordinateEncoder):
  """MLP-based encoder."""

  position_encoder: nn.Module
  momentum_encoder: nn.Module
  transform_fn: nn.Module
  latent_position_decoder: nn.Module
  latent_momentum_decoder: nn.Module

  def __call__(self, positions,
               momentums):
    # Encode input coordinates.
    positions = self.position_encoder(positions)
    momentums = self.momentum_encoder(momentums)

    # Transform to new coordinates.
    coords = jnp.concatenate([positions, momentums], axis=-1)
    coords = self.transform_fn(coords)

    # Decode to final coordinates.
    latent_positions = self.latent_position_decoder(coords)
    latent_momentums = self.latent_momentum_decoder(coords)
    return latent_positions, latent_momentums


class MLPDecoder(CoordinateDecoder):
  """MLP-based decoder."""

  latent_position_encoder: nn.Module
  latent_momentum_encoder: nn.Module
  transform_fn: nn.Module
  position_decoder: nn.Module
  momentum_decoder: nn.Module

  def __call__(self, latent_positions,
               latent_momentums):
    # Encode input coordinates.
    latent_positions = self.latent_position_encoder(latent_positions)
    latent_momentums = self.latent_momentum_encoder(latent_momentums)

    # Transform to new coordinates.
    coords = jnp.concatenate([latent_positions, latent_momentums], axis=-1)
    coords = self.transform_fn(coords)

    # Decode to final coordinates.
    positions = self.position_decoder(coords)
    momentums = self.momentum_decoder(coords)
    return positions, momentums
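

# Illustrative wiring sketch (not part of the original module): the MLP-based
# encoder maps (positions, momentums) to latent coordinates through
# per-coordinate MLPs and a joint transform. All names and sizes below are
# hypothetical placeholders.
def _example_mlp_encoder():
  encoder = MLPEncoder(
      position_encoder=MLP([16], jax.nn.softplus),
      momentum_encoder=MLP([16], jax.nn.softplus),
      transform_fn=MLP([32, 32], jax.nn.softplus),
      latent_position_decoder=MLP([1], None),
      latent_momentum_decoder=MLP([1], None),
  )
  positions = jnp.ones((4, 1))
  momentums = jnp.ones((4, 1))
  params = encoder.init(jax.random.PRNGKey(0), positions, momentums)
  latent_positions, latent_momentums = encoder.apply(
      params, positions, momentums)
  chex.assert_shape(latent_positions, (4, 1))
  chex.assert_shape(latent_momentums, (4, 1))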


class FlowEncoder(CoordinateEncoder):
  """Flow-based encoder for the Action-Angle Neural Network."""

  flow: NormalizingFlow

  def __call__(self, positions,
               momentums):
    # Pass through forward flow to obtain latent positions and momentums.
    coords = jnp.concatenate([positions, momentums], axis=-1)
    coords = self.flow.forward(coords)

    assert len(coords.shape) == 2, coords.shape
    assert coords.shape[-1] % 2 == 0, coords.shape

    num_positions = coords.shape[-1] // 2
    latent_positions = coords[Ellipsis, :num_positions]
    latent_momentums = coords[Ellipsis, num_positions:]
    return latent_positions, latent_momentums


class FlowDecoder(CoordinateDecoder):
  """Flow-based decoder for the Action-Angle Neural Network."""

  flow: NormalizingFlow

  def __call__(self, latent_positions,
               latent_momentums):
    # Pass through inverse flow to obtain positions and momentums.
    coords = jnp.concatenate([latent_positions, latent_momentums], axis=-1)
    coords = self.flow.inverse(coords)

    assert len(coords.shape) == 2, coords.shape
    assert coords.shape[-1] % 2 == 0, coords.shape

    num_positions = coords.shape[-1] // 2
    positions = coords[Ellipsis, :num_positions]
    momentums = coords[Ellipsis, num_positions:]
    return positions, momentums


class ActionAngleNetwork(nn.Module):
  """Implementation of an Action-Angle Neural Network."""

  encoder: CoordinateEncoder
  angular_velocity_net: nn.Module
  decoder: CoordinateDecoder
  polar_action_angles: bool
  single_step_predictions: bool = True

  def predict_single_step(
      self, positions, momentums,
      time_deltas
  ):
    """Predicts future coordinates with one-step prediction."""
    time_deltas = jnp.squeeze(time_deltas)
    time_deltas = jnp.expand_dims(time_deltas, axis=range(time_deltas.ndim, 2))
    assert time_deltas.ndim == 2

    # Encode.
    current_latent_positions, current_latent_momentums = self.encoder(
        positions, momentums)
    if self.polar_action_angles:
      actions, current_angles = jax.vmap(cartesian_to_polar)(
          current_latent_positions, current_latent_momentums)
    else:
      actions, current_angles = current_latent_positions, current_latent_momentums

    # Compute angular velocities.
    angular_velocities = self.angular_velocity_net(actions)
    assert angular_velocities.shape[-1] == current_angles.shape[-1]

    # Fast-forward.
    future_angles = current_angles + angular_velocities * time_deltas
    if self.polar_action_angles:
      future_angles = (future_angles + jnp.pi) % (2 * jnp.pi) - (jnp.pi)

    # Decode.
    if self.polar_action_angles:
      future_latent_positions, future_latent_momentums = jax.vmap(
          polar_to_cartesian)(actions, future_angles)
    else:
      future_latent_positions, future_latent_momentums = actions, future_angles
    predicted_positions, predicted_momentums = self.decoder(
        future_latent_positions, future_latent_momentums)

    return predicted_positions, predicted_momentums, dict(
        current_latent_positions=current_latent_positions,
        current_latent_momentums=current_latent_momentums,
        actions=actions,
        current_angles=current_angles,
        angular_velocities=angular_velocities,
        future_angles=future_angles,
        future_latent_positions=future_latent_positions,
        future_latent_momentums=future_latent_momentums)

  def predict_multi_step(
      self, init_positions, init_momentums,
      time_deltas
  ):
    """Predicts future coordinates with multi-step prediction."""
    time_deltas = jnp.expand_dims(time_deltas, axis=range(time_deltas.ndim, 2))
    assert time_deltas.ndim == 2

    # init_positions and init_momentums have shape [1 x num_trajectories].
    assert len(init_positions.shape) == 2
    assert len(init_momentums.shape) == 2
    assert init_positions.shape[0] == 1
    assert init_momentums.shape[0] == 1

    # Encode.
    current_latent_positions, current_latent_momentums = self.encoder(
        init_positions, init_momentums)
    if self.polar_action_angles:
      actions, current_angles = jax.vmap(cartesian_to_polar)(
          current_latent_positions, current_latent_momentums)
    else:
      actions, current_angles = current_latent_positions, current_latent_momentums

    # Compute angular velocities.
    angular_velocities = self.angular_velocity_net(actions)
    assert angular_velocities.shape[-1] == current_angles.shape[-1]

    # Fast-forward.
    future_angles = current_angles + angular_velocities * time_deltas

    # actions has shape [1 x num_trajectories].
    # future_angles has shape [T x num_trajectories].
    if self.polar_action_angles:
      future_angles = (future_angles + jnp.pi) % (2 * jnp.pi) - (jnp.pi)
      future_latent_positions, future_latent_momentums = jax.vmap(
          polar_to_cartesian, in_axes=(None, 0))(actions[0], future_angles)
    else:
      future_latent_positions, future_latent_momentums = jax.vmap(
          lambda x, y: (x, y), in_axes=(None, 0))(actions[0], future_angles)

    # predicted_positions has shape [T x num_trajectories].
    # predicted_momentums has shape [T x num_trajectories].
    predicted_positions, predicted_momentums = self.decoder(
        future_latent_positions, future_latent_momentums)

    return predicted_positions, predicted_momentums, dict(
        current_latent_positions=current_latent_positions,
        current_latent_momentums=current_latent_momentums,
        actions=actions,
        current_angles=current_angles,
        angular_velocities=angular_velocities,
        future_angles=future_angles,
        future_latent_positions=future_latent_positions,
        future_latent_momentums=future_latent_momentums)

  def encode_decode(self, positions,
                    momentums):
    """Encodes and decodes the given coordinates."""
    actions, current_angles = self.encoder(positions, momentums)
    return self.decoder(actions, current_angles)

  def __call__(
      self, positions, momentums,
      time_deltas
  ):
    time_deltas = jnp.asarray(time_deltas)
    if self.single_step_predictions or time_deltas.ndim == 0:
      return self.predict_single_step(positions, momentums, time_deltas)
    return self.predict_multi_step(positions, momentums, time_deltas)
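

# Illustrative end-to-end sketch (not part of the original module): assembles
# an Action-Angle Neural Network from a symplectic flow encoder/decoder and an
# MLP for the angular velocities. The two flows below are independent copies;
# whether the encoder and decoder should share parameters is a separate design
# choice not shown here. All names and sizes are hypothetical placeholders.
def _example_action_angle_network():
  model = ActionAngleNetwork(
      encoder=FlowEncoder(flow=SymplecticLinearFlow(operation_input_dims=1)),
      angular_velocity_net=MLP([8, 1], jax.nn.softplus),
      decoder=FlowDecoder(flow=SymplecticLinearFlow(operation_input_dims=1)),
      polar_action_angles=True)
  positions = jnp.ones((4, 1))
  momentums = jnp.ones((4, 1))
  time_delta = 0.1
  params = model.init(
      jax.random.PRNGKey(0), positions, momentums, time_delta)
  predicted_positions, predicted_momentums, _ = model.apply(
      params, positions, momentums, time_delta)
  chex.assert_shape(predicted_positions, (4, 1))
  chex.assert_shape(predicted_momentums, (4, 1))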


class EulerUpdateNetwork(nn.Module):
  """A neural network that performs Euler updates with predicted position and momentum derivatives."""

  encoder: CoordinateEncoder
  derivative_net: nn.Module
  decoder: CoordinateDecoder

  @nn.compact
  def __call__(self, positions, momentums,
               time_delta):
    # Encode.
    positions, momentums = self.encoder(positions, momentums)

    # Predict derivatives for each coordinate.
    coords = jnp.concatenate([positions, momentums], axis=-1)
    derivatives = self.derivative_net(coords)

    # Unpack.
    num_positions = derivatives.shape[-1] // 2
    position_derivative = derivatives[Ellipsis, :num_positions]
    momentum_derivative = derivatives[Ellipsis, num_positions:]

    # Perform Euler update.
    predicted_positions = positions + position_derivative * time_delta
    predicted_momentums = momentums + momentum_derivative * time_delta

    # Decode.
    predicted_positions, predicted_momentums = self.decoder(
        predicted_positions, predicted_momentums)
    return predicted_positions, predicted_momentums, None