# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definition of models."""

import abc
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import chex
import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax


@jax.jit
def inverse_leaky_relu(y):
  """Inverse of the default jax.nn.leaky_relu."""
  alpha = jnp.where(y > 0, 1, 0.01)
  return y / alpha


@jax.jit
def inverse_softplus(y):
  """Inverse of jax.nn.softplus, adapted from TensorFlow Probability."""
  threshold = jnp.log(jnp.finfo(jnp.float32).eps) + 2.
  is_too_small = y < jnp.exp(threshold)
  is_too_large = y > -threshold
  too_small_value = jnp.log(y)
  too_large_value = y
  # Substitute a dummy value in the over/underflow regimes so the general
  # formula below stays finite; the appropriate branch is selected afterwards.
  y = jnp.where(is_too_small | is_too_large, 1., y)
  x = y + jnp.log(-jnp.expm1(-y))
  return jnp.where(is_too_small, too_small_value,
                   jnp.where(is_too_large, too_large_value, x))
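
# A quick numerical check (a sketch, not part of the original module): the
# inverse should undo jax.nn.softplus across both the stable regime and the
# clamped branches.
#
#   x = jnp.linspace(-20., 20., 9)
#   assert jnp.allclose(inverse_softplus(jax.nn.softplus(x)), x, atol=1e-4)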


_CUSTOM_ACTIVATIONS = {
    'leaky_relu': jax.nn.leaky_relu,
    'softplus': jax.nn.softplus,
}


_INVERSE_CUSTOM_ACTIVATIONS = {
    'leaky_relu': inverse_leaky_relu,
    'softplus': inverse_softplus,
}


def create_activation_bijector(activation):
  """Creates a bijector for the given activation function."""
  if activation in _CUSTOM_ACTIVATIONS:
    return distrax.Lambda(
        forward=_CUSTOM_ACTIVATIONS[activation],
        inverse=_INVERSE_CUSTOM_ACTIVATIONS[activation])

  activation_fn = getattr(jax.nn, activation, None)
  return distrax.as_bijector(activation_fn)
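
# Usage sketch (an illustration, not part of the original module): names with
# a registered inverse get an explicit distrax.Lambda; any other name must
# resolve to a jax.nn function that distrax can wrap.
#
#   bijector = create_activation_bijector('softplus')
#   y = bijector.forward(jnp.array([0.5]))
#   x = bijector.inverse(y)  # Recovers [0.5] up to float32 tolerance.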


def cartesian_to_polar(x, y):
  """Converts cartesian (x, y) coordinates to polar (r, theta) coordinates."""
  r = jnp.sqrt(x**2 + y**2)
  theta = jnp.arctan2(y, x)
  return r, theta


def polar_to_cartesian(r, theta):
  """Converts polar (r, theta) coordinates to cartesian (x, y) coordinates."""
  x = r * jnp.cos(theta)
  y = r * jnp.sin(theta)
  return x, y
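
# Worked example (a sketch, not part of the original module):
#
#   r, theta = cartesian_to_polar(jnp.asarray(3.), jnp.asarray(4.))
#   # r == 5.0 and theta == arctan2(4, 3) ≈ 0.9273 radians.
#   x, y = polar_to_cartesian(r, theta)  # Recovers (3., 4.).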


class MLP(nn.Module):
  """A multi-layer perceptron (MLP)."""

  latent_sizes: Sequence[int]
  activation: Optional[Callable[[chex.Array], chex.Array]]
  skip_connections: bool = True
  activate_final: bool = False

  @nn.compact
  def __call__(self, inputs):
    for index, dim in enumerate(self.latent_sizes):
      next_inputs = nn.Dense(dim)(inputs)

      if index != len(self.latent_sizes) - 1 or self.activate_final:
        if self.activation is not None:
          next_inputs = self.activation(next_inputs)

      if self.skip_connections and next_inputs.shape == inputs.shape:
        next_inputs = next_inputs + inputs

      inputs = next_inputs
    return inputs
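
# Usage sketch (an illustration; the sizes below are arbitrary assumptions):
#
#   mlp = MLP(latent_sizes=[32, 32], activation=jax.nn.relu)
#   params = mlp.init(jax.random.PRNGKey(0), jnp.ones((8, 16)))
#   outputs = mlp.apply(params, jnp.ones((8, 16)))  # Shape (8, 32).
#
# The residual connection only fires between equal-width layers: here it
# applies to the second 32 -> 32 layer but not the first 16 -> 32 layer.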


class NormalizingFlow(abc.ABC, nn.Module):
  """Base class for normalizing flows."""

  @abc.abstractmethod
  def forward(self, inputs):
    """Computes the forward map."""

  @abc.abstractmethod
  def inverse(self, inputs):
    """Computes the inverse map."""

  def __call__(self, inputs):
    return self.forward(inputs)


class MaskedCouplingFlowConditioner(nn.Module):
  """Conditioner for the masked coupling normalizing flow."""

  event_shape: Sequence[int]
  latent_sizes: Sequence[int]
  activation: Callable[[chex.Array], chex.Array]
  num_bijector_params: int

  @nn.compact
  def __call__(self, inputs):
    inputs = jnp.reshape(inputs, (inputs.shape[0], -1))
    inputs = MLP(
        self.latent_sizes, self.activation, activate_final=True)(inputs)
    inputs = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params)(
        inputs)
    inputs = jnp.reshape(
        inputs, inputs.shape[:-1] + tuple(self.event_shape) +
        (self.num_bijector_params,))
    return inputs


class MaskedCouplingNormalizingFlow(NormalizingFlow):
  """Implements a masked coupling normalizing flow."""

  event_shape: Sequence[int]
  bijector_fn: Callable[[optax.Params], distrax.Bijector]
  conditioners: Sequence[MaskedCouplingFlowConditioner]

  def setup(self):
    # Alternating binary mask.
    mask = jnp.arange(0, np.prod(self.event_shape)) % 2
    mask = jnp.reshape(mask, self.event_shape)
    mask = mask.astype(bool)

    layers = []
    for conditioner in self.conditioners:
      layer = distrax.MaskedCoupling(
          mask=mask, bijector=self.bijector_fn, conditioner=conditioner)
      layers.append(layer)

      # Flip the mask after each layer.
      mask = jnp.logical_not(mask)

    # Chain layers to create the flow.
    self.flow = distrax.Chain(layers)

  def forward(self, inputs):
    """Encodes inputs as latent vectors."""
    return self.flow.forward(inputs)

  def inverse(self, inputs):
    """Applies the inverse flow to the latents."""
    return self.flow.inverse(inputs)
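
# Wiring sketch (an illustration; the rational-quadratic-spline bijector_fn
# follows the standard distrax flow recipe and is an assumption, not something
# this module prescribes):
#
#   num_bins = 4
#   num_bijector_params = 3 * num_bins + 1  # Spline parameterization.
#   def bijector_fn(params):
#     return distrax.RationalQuadraticSpline(
#         params, range_min=-5., range_max=5.)
#   conditioners = [
#       MaskedCouplingFlowConditioner(
#           event_shape=(4,), latent_sizes=[64, 64], activation=jax.nn.relu,
#           num_bijector_params=num_bijector_params)
#       for _ in range(3)
#   ]
#   flow = MaskedCouplingNormalizingFlow(
#       event_shape=(4,), bijector_fn=bijector_fn, conditioners=conditioners)
#   params = flow.init(jax.random.PRNGKey(0), jnp.ones((8, 4)))
#   latents = flow.apply(params, jnp.ones((8, 4)))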


class OneDimensionalNormalizingFlow(NormalizingFlow):
  """Implements a one-dimensional normalizing flow."""

  num_layers: int
  activation: distrax.Bijector

  def setup(self):
    layers = []
    for index in range(self.num_layers):
      scale = self.param(f'scale_{index}', nn.initializers.lecun_normal(),
                         (1, 1))
      shift = self.param(f'shift_{index}', nn.initializers.lecun_normal(),
                         (1, 1))
      layer = distrax.ScalarAffine(scale=scale, shift=shift)
      layer = distrax.Chain([self.activation, layer])
      layers.append(layer)

    self.flow = distrax.Chain(layers)

  def inverse(self, inputs):
    """Applies the inverse flow to the latents."""
    return self.flow.inverse(inputs)

  def forward(self, inputs):
    """Encodes inputs as latent vectors."""
    return self.flow.forward(inputs)
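
# Usage sketch (an illustration, not part of the original module): each layer
# applies a learned scalar affine map followed by the activation bijector,
# since distrax.Chain applies its list right-to-left.
#
#   flow = OneDimensionalNormalizingFlow(
#       num_layers=3, activation=create_activation_bijector('leaky_relu'))
#   params = flow.init(jax.random.PRNGKey(0), jnp.ones((8, 1)))
#   latents = flow.apply(params, jnp.ones((8, 1)))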


class PointwiseNormalizingFlow(NormalizingFlow):
  """Implements a pointwise symplectic normalizing flow."""

  base_flow: NormalizingFlow
  base_flow_input_dims: int
  switch: bool = False

  def forward(self, inputs):
    num_dims = self.base_flow_input_dims
    assert inputs.shape[-1] == 2 * num_dims, (
        f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.')

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    def dot_product_for_forward(coords_transformed):
      coords = self.base_flow.inverse(coords_transformed)
      return jnp.dot(coords.squeeze(axis=-1), second_coords.squeeze(axis=-1))  # pytype: disable=bad-return-type  # jnp-type

    first_coords_transformed = self.base_flow.forward(first_coords)
    second_coords_transformed = jax.grad(dot_product_for_forward)(
        first_coords_transformed)

    return jnp.concatenate(
        (first_coords_transformed, second_coords_transformed), axis=-1)

  def inverse(self, inputs):
    num_dims = self.base_flow_input_dims
    assert inputs.shape[-1] == 2 * num_dims, (
        f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.')

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    def dot_product_for_inverse(coords_inverted):
      coords = self.base_flow.forward(coords_inverted)
      return jnp.dot(coords.squeeze(axis=-1), second_coords.squeeze(axis=-1))  # pytype: disable=bad-return-type  # jnp-type

    first_coords_inverted = self.base_flow.inverse(first_coords)
    second_coords_inverted = jax.grad(dot_product_for_inverse)(
        first_coords_inverted)

    return jnp.concatenate((first_coords_inverted, second_coords_inverted),
                           axis=-1)
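
# Why the gradient trick above preserves the symplectic form (a note, not
# from the original module): for a point transformation Q = f(q), the
# conjugate momenta must transform as P = (∂f/∂q)^{-T} p so that
# dQ ∧ dP = dq ∧ dp. Since
#
#   grad_Q <f⁻¹(Q), p> = (∂f⁻¹/∂Q)ᵀ p = (∂f/∂q)^{-T} p,
#
# differentiating the inner product of the inverse map with the untouched
# coordinate block yields exactly the correctly transformed block.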


class LinearBasedConditioner(nn.Module):
  """Linear module from SympNets."""

  @nn.compact
  def __call__(self, inputs):
    num_dims = inputs.shape[-1]
    w = self.param('w', nn.initializers.normal(0.01), (num_dims, num_dims))
    return inputs @ (w + w.T)


class ActivationBasedConditioner(nn.Module):
  """Activation module from SympNets."""

  activation: Callable[[chex.Array], chex.Array]

  @nn.compact
  def __call__(self, inputs):
    num_dims = inputs.shape[-1]
    a = self.param('a', nn.initializers.normal(0.01), (num_dims,))
    return self.activation(inputs) * a


class GradientBasedConditioner(nn.Module):
  """Gradient module from SympNets."""

  activation: Callable[[chex.Array], chex.Array]
  projection_dims: int
  skip_connections: bool = True

  @nn.compact
  def __call__(self, inputs):
    num_dims = inputs.shape[-1]
    w = self.param('w', nn.initializers.normal(0.01),
                   (num_dims, self.projection_dims))
    b = self.param('b', nn.initializers.normal(0.01), (self.projection_dims,))
    a = self.param('a', nn.initializers.zeros, (self.projection_dims,))
    gate = self.param('gate', nn.initializers.zeros, (num_dims,))

    outputs = self.activation(inputs @ w + b)
    outputs = outputs * a
    outputs = outputs @ w.T
    if self.skip_connections:
      outputs += inputs
    outputs *= gate
    return outputs


class ShearNormalizingFlow(NormalizingFlow):
  """Implements a shearing normalizing flow."""

  conditioner: nn.Module
  conditioner_input_dims: int
  switch: bool = False

  def forward(self, inputs):
    num_dims = self.conditioner_input_dims
    assert inputs.shape[-1] == 2 * num_dims, (
        f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.')

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    first_coords_transformed = first_coords + self.conditioner(second_coords)

    if self.switch:
      first_coords_transformed, second_coords = (
          second_coords, first_coords_transformed)

    return jnp.concatenate((first_coords_transformed, second_coords), axis=-1)

  def inverse(self, inputs):
    num_dims = self.conditioner_input_dims
    assert inputs.shape[-1] == 2 * num_dims, (
        f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.')

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]

    if self.switch:
      first_coords, second_coords = second_coords, first_coords

    first_coords_inverted = first_coords - self.conditioner(second_coords)

    if self.switch:
      first_coords_inverted, second_coords = (
          second_coords, first_coords_inverted)

    return jnp.concatenate((first_coords_inverted, second_coords), axis=-1)
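
# A note plus usage sketch (illustrative, not from the original module): the
# shear is trivially invertible for any conditioner, and it is symplectic
# when the conditioner is the gradient of a scalar potential, as with the
# SympNets linear module above (x @ (w + wᵀ) is the gradient of
# ½ xᵀ(w + wᵀ)x).
#
#   flow = ShearNormalizingFlow(
#       conditioner=LinearBasedConditioner(), conditioner_input_dims=2)
#   inputs = jnp.ones((8, 4))
#   params = flow.init(jax.random.PRNGKey(0), inputs)
#   outputs = flow.apply(params, inputs)
#   recovered = flow.apply(params, outputs, method=flow.inverse)  # == inputs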


class SymplecticLinearFlow(NormalizingFlow):
  """Implements a SymplecticLinear layer from 'Learning Symmetries of
  Classical Integrable Systems', consisting of a shift, scale and rotation.
  """

  operation_input_dims: int

  def setup(self):
    num_dims = self.operation_input_dims
    self.shift_val = self.param('shift_val', nn.initializers.zeros,
                                (num_dims,))
    self.scale_val = self.param('scale_val', nn.initializers.ones,
                                (num_dims,))
    self.rotate_val = self.param('rotate_val', nn.initializers.zeros,
                                 (num_dims,))

  def extract_coords(self, inputs):
    """Returns the two sets of coordinates."""
    num_dims = self.operation_input_dims
    assert inputs.shape[-1] == 2 * num_dims, (
        f'Got inputs of shape {inputs.shape} for num_dims = {num_dims}.')

    first_coords = inputs[Ellipsis, :num_dims]
    second_coords = inputs[Ellipsis, num_dims:]
    return first_coords, second_coords

  def forward(self, inputs):
    """Runs the forward pass."""
    first_coords, second_coords = self.extract_coords(inputs)
    first_coords, second_coords = self.shift(
        first_coords, second_coords, forward=True)
    first_coords, second_coords = self.scale(
        first_coords, second_coords, forward=True)
    first_coords, second_coords = self.rotate(
        first_coords, second_coords, forward=True)
    return jnp.concatenate((first_coords, second_coords), axis=-1)

  def inverse(self, inputs):
    """Computes the inverse of self.forward()."""
    first_coords, second_coords = self.extract_coords(inputs)
    first_coords, second_coords = self.rotate(
        first_coords, second_coords, forward=False)
    first_coords, second_coords = self.scale(
        first_coords, second_coords, forward=False)
    first_coords, second_coords = self.shift(
        first_coords, second_coords, forward=False)
    return jnp.concatenate((first_coords, second_coords), axis=-1)

  def shift(self, first_coords, second_coords, forward):
    """Performs the shift transformation on one of the coordinates."""
    shift = self.shift_val
    if not forward:
      shift = -shift
    first_coords_shifted = first_coords + second_coords * shift
    second_coords_shifted = second_coords
    return first_coords_shifted, second_coords_shifted

  def scale(self, first_coords, second_coords, forward):
    """Scales the coordinates in a symplectic manner."""
    scale = self.scale_val
    if not forward:
      scale = 1 / scale
    first_coords_scaled = first_coords * scale
    second_coords_scaled = second_coords / scale
    return first_coords_scaled, second_coords_scaled

  def rotate(self, first_coords, second_coords, forward):
    """Rotates all of the coordinates."""
    theta = self.rotate_val
    if not forward:
      theta = -theta
    first_coords_rotated = first_coords * jnp.cos(
        theta) - second_coords * jnp.sin(theta)
    second_coords_rotated = first_coords * jnp.sin(
        theta) + second_coords * jnp.cos(theta)
    return first_coords_rotated, second_coords_rotated
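
# Why each sub-operation is symplectic (a note, not from the original
# module): per coordinate pair (q, p), the shift has Jacobian
# [[1, s], [0, 1]], the scale [[c, 0], [0, 1/c]], and the rotation
# [[cos θ, -sin θ], [sin θ, cos θ]]; each has determinant 1 and therefore
# preserves dq ∧ dp, and inverse() undoes the three maps in reverse order.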


class SequentialFlow(NormalizingFlow):
  """Adaptation of nn.Sequential() for flows."""

  flows: Sequence[NormalizingFlow]

  def forward(self, inputs):
    for flow in self.flows:
      inputs = flow.forward(inputs)
    return inputs

  def inverse(self, inputs):
    for flow in reversed(self.flows):
      inputs = flow.inverse(inputs)
    return inputs

  def __call__(self, inputs):
    return self.forward(inputs)
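
# Composition sketch (illustrative wiring, not from the original module);
# inverse() walks the list in reverse, so the chain inverts exactly.
#
#   flow = SequentialFlow(flows=[
#       SymplecticLinearFlow(operation_input_dims=1),
#       ShearNormalizingFlow(
#           conditioner=LinearBasedConditioner(), conditioner_input_dims=1),
#   ])
#   params = flow.init(jax.random.PRNGKey(0), jnp.ones((8, 2)))
#   latents = flow.apply(params, jnp.ones((8, 2)))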


class CoordinateEncoder(abc.ABC, nn.Module):
  """Base class for encoders."""

  @abc.abstractmethod
  def __call__(self, positions, momentums):
    """Returns corresponding angles and momentums by encoding inputs."""


class CoordinateDecoder(abc.ABC, nn.Module):
  """Base class for decoders."""

  @abc.abstractmethod
  def __call__(self, actions, angles):
    """Returns corresponding positions and momentums by decoding inputs."""


class MLPEncoder(CoordinateEncoder):
  """MLP-based encoder."""

  position_encoder: nn.Module
  momentum_encoder: nn.Module
  transform_fn: nn.Module
  latent_position_decoder: nn.Module
  latent_momentum_decoder: nn.Module

  def __call__(self, positions, momentums):
    # Encode input coordinates.
    positions = self.position_encoder(positions)
    momentums = self.momentum_encoder(momentums)

    # Transform to new coordinates.
    coords = jnp.concatenate([positions, momentums], axis=-1)
    coords = self.transform_fn(coords)

    # Decode to final coordinates.
    latent_positions = self.latent_position_decoder(coords)
    latent_momentums = self.latent_momentum_decoder(coords)
    return latent_positions, latent_momentums


class MLPDecoder(CoordinateDecoder):
  """MLP-based decoder."""

  latent_position_encoder: nn.Module
  latent_momentum_encoder: nn.Module
  transform_fn: nn.Module
  position_decoder: nn.Module
  momentum_decoder: nn.Module

  def __call__(self, latent_positions, latent_momentums):
    # Encode input coordinates.
    latent_positions = self.latent_position_encoder(latent_positions)
    latent_momentums = self.latent_momentum_encoder(latent_momentums)

    # Transform to new coordinates.
    coords = jnp.concatenate([latent_positions, latent_momentums], axis=-1)
    coords = self.transform_fn(coords)

    # Decode to final coordinates.
    positions = self.position_decoder(coords)
    momentums = self.momentum_decoder(coords)
    return positions, momentums
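
# Wiring sketch (sizes and sub-modules are illustrative assumptions):
#
#   encoder = MLPEncoder(
#       position_encoder=MLP([16], jax.nn.relu),
#       momentum_encoder=MLP([16], jax.nn.relu),
#       transform_fn=MLP([32], jax.nn.relu),
#       latent_position_decoder=MLP([1], jax.nn.relu),
#       latent_momentum_decoder=MLP([1], jax.nn.relu))
#   q, p = jnp.ones((8, 1)), jnp.ones((8, 1))
#   params = encoder.init(jax.random.PRNGKey(0), q, p)
#   latent_q, latent_p = encoder.apply(params, q, p)
#
# MLPDecoder mirrors this structure in the opposite direction; unlike the
# flow-based pair below, it is not an exact inverse of the encoder.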


class FlowEncoder(CoordinateEncoder):
  """Flow-based encoder for the Action-Angle Neural Network."""

  flow: NormalizingFlow

  def __call__(self, positions, momentums):
    # Pass through forward flow to obtain latent positions and momentums.
    coords = jnp.concatenate([positions, momentums], axis=-1)
    coords = self.flow.forward(coords)

    assert len(coords.shape) == 2, coords.shape
    assert coords.shape[-1] % 2 == 0, coords.shape

    num_positions = coords.shape[-1] // 2
    latent_positions = coords[Ellipsis, :num_positions]
    latent_momentums = coords[Ellipsis, num_positions:]
    return latent_positions, latent_momentums


class FlowDecoder(CoordinateDecoder):
  """Flow-based decoder for the Action-Angle Neural Network."""

  flow: NormalizingFlow

  def __call__(self, latent_positions, latent_momentums):
    # Pass through inverse flow to obtain positions and momentums.
    coords = jnp.concatenate([latent_positions, latent_momentums], axis=-1)
    coords = self.flow.inverse(coords)

    assert len(coords.shape) == 2, coords.shape
    assert coords.shape[-1] % 2 == 0, coords.shape

    num_positions = coords.shape[-1] // 2
    positions = coords[Ellipsis, :num_positions]
    momentums = coords[Ellipsis, num_positions:]
    return positions, momentums
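
# Usage sketch (an illustration, not from the original module): because the
# decoder applies flow.inverse, an encoder/decoder pair built around the same
# flow instance (shared as attributes of one parent module, so Flax ties
# their parameters) inverts the encoding exactly.
#
#   encoder = FlowEncoder(flow=SymplecticLinearFlow(operation_input_dims=1))
#   q, p = jnp.ones((8, 1)), jnp.ones((8, 1))
#   params = encoder.init(jax.random.PRNGKey(0), q, p)
#   latent_q, latent_p = encoder.apply(params, q, p)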


class ActionAngleNetwork(nn.Module):
  """Implementation of an Action-Angle Neural Network."""

  encoder: CoordinateEncoder
  angular_velocity_net: nn.Module
  decoder: CoordinateDecoder
  polar_action_angles: bool
  single_step_predictions: bool = True

  def predict_single_step(self, positions, momentums, time_deltas):
    """Predicts future coordinates with one-step prediction."""
    time_deltas = jnp.squeeze(time_deltas)
    time_deltas = jnp.expand_dims(time_deltas, axis=range(time_deltas.ndim, 2))
    assert time_deltas.ndim == 2

    # Encode.
    current_latent_positions, current_latent_momentums = self.encoder(
        positions, momentums)
    if self.polar_action_angles:
      actions, current_angles = jax.vmap(cartesian_to_polar)(
          current_latent_positions, current_latent_momentums)
    else:
      actions, current_angles = (current_latent_positions,
                                 current_latent_momentums)

    # Compute angular velocities.
    angular_velocities = self.angular_velocity_net(actions)
    assert angular_velocities.shape[-1] == current_angles.shape[-1]

    # Fast-forward.
    future_angles = current_angles + angular_velocities * time_deltas
    if self.polar_action_angles:
      # Wrap angles back into [-pi, pi).
      future_angles = (future_angles + jnp.pi) % (2 * jnp.pi) - jnp.pi

    # Decode.
    if self.polar_action_angles:
      future_latent_positions, future_latent_momentums = jax.vmap(
          polar_to_cartesian)(actions, future_angles)
    else:
      future_latent_positions, future_latent_momentums = actions, future_angles
    predicted_positions, predicted_momentums = self.decoder(
        future_latent_positions, future_latent_momentums)

    return predicted_positions, predicted_momentums, dict(
        current_latent_positions=current_latent_positions,
        current_latent_momentums=current_latent_momentums,
        actions=actions,
        current_angles=current_angles,
        angular_velocities=angular_velocities,
        future_angles=future_angles,
        future_latent_positions=future_latent_positions,
        future_latent_momentums=future_latent_momentums)

  def predict_multi_step(self, init_positions, init_momentums, time_deltas):
    """Predicts future coordinates with multi-step prediction."""
    time_deltas = jnp.expand_dims(time_deltas, axis=range(time_deltas.ndim, 2))
    assert time_deltas.ndim == 2

    # init_positions and init_momentums have shape [1 x num_trajectories].
    assert len(init_positions.shape) == 2
    assert len(init_momentums.shape) == 2
    assert init_positions.shape[0] == 1
    assert init_momentums.shape[0] == 1

    # Encode.
    current_latent_positions, current_latent_momentums = self.encoder(
        init_positions, init_momentums)
    if self.polar_action_angles:
      actions, current_angles = jax.vmap(cartesian_to_polar)(
          current_latent_positions, current_latent_momentums)
    else:
      actions, current_angles = (current_latent_positions,
                                 current_latent_momentums)

    # Compute angular velocities.
    angular_velocities = self.angular_velocity_net(actions)
    assert angular_velocities.shape[-1] == current_angles.shape[-1]

    # Fast-forward.
    future_angles = current_angles + angular_velocities * time_deltas

    # actions has shape [1 x num_trajectories].
    # future_angles has shape [T x num_trajectories].
    if self.polar_action_angles:
      future_angles = (future_angles + jnp.pi) % (2 * jnp.pi) - jnp.pi
      future_latent_positions, future_latent_momentums = jax.vmap(
          polar_to_cartesian, in_axes=(None, 0))(actions[0], future_angles)
    else:
      future_latent_positions, future_latent_momentums = jax.vmap(
          lambda x, y: (x, y), in_axes=(None, 0))(actions[0], future_angles)

    # predicted_positions has shape [T x num_trajectories].
    # predicted_momentums has shape [T x num_trajectories].
    predicted_positions, predicted_momentums = self.decoder(
        future_latent_positions, future_latent_momentums)

    return predicted_positions, predicted_momentums, dict(
        current_latent_positions=current_latent_positions,
        current_latent_momentums=current_latent_momentums,
        actions=actions,
        current_angles=current_angles,
        angular_velocities=angular_velocities,
        future_angles=future_angles,
        future_latent_positions=future_latent_positions,
        future_latent_momentums=future_latent_momentums)

  def encode_decode(self, positions, momentums):
    """Encodes and decodes the given coordinates."""
    actions, current_angles = self.encoder(positions, momentums)
    return self.decoder(actions, current_angles)

  def __call__(self, positions, momentums, time_deltas):
    time_deltas = jnp.asarray(time_deltas)
    if self.single_step_predictions or time_deltas.ndim == 0:
      return self.predict_single_step(positions, momentums, time_deltas)
    return self.predict_multi_step(positions, momentums, time_deltas)
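
# End-to-end sketch (all sizes and sub-modules here are illustrative
# assumptions): predicts coordinates one time step ahead.
#
#   model = ActionAngleNetwork(
#       encoder=FlowEncoder(flow=SymplecticLinearFlow(operation_input_dims=1)),
#       angular_velocity_net=MLP(latent_sizes=[32, 1], activation=jax.nn.relu),
#       decoder=FlowDecoder(flow=SymplecticLinearFlow(operation_input_dims=1)),
#       polar_action_angles=True)
#   q, p = jnp.ones((1, 1)), jnp.ones((1, 1))
#   params = model.init(jax.random.PRNGKey(0), q, p, 0.1)
#   predicted_q, predicted_p, aux = model.apply(params, q, p, 0.1)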


class EulerUpdateNetwork(nn.Module):
  """A neural network that performs Euler updates with predicted position and
  momentum derivatives.
  """

  encoder: CoordinateEncoder
  derivative_net: nn.Module
  decoder: CoordinateDecoder

  @nn.compact
  def __call__(self, positions, momentums, time_delta):
    # Encode.
    positions, momentums = self.encoder(positions, momentums)

    # Predict derivatives for each coordinate.
    coords = jnp.concatenate([positions, momentums], axis=-1)
    derivatives = self.derivative_net(coords)

    # Unpack.
    num_positions = derivatives.shape[-1] // 2
    position_derivative = derivatives[Ellipsis, :num_positions]
    momentum_derivative = derivatives[Ellipsis, num_positions:]

    # Perform Euler update.
    predicted_positions = positions + position_derivative * time_delta
    predicted_momentums = momentums + momentum_derivative * time_delta

    # Decode.
    predicted_positions, predicted_momentums = self.decoder(
        predicted_positions, predicted_momentums)
    return predicted_positions, predicted_momentums, None
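
# Baseline usage sketch (illustrative wiring; the trailing None fills the
# auxiliary-outputs slot that ActionAngleNetwork returns as a dict):
#
#   model = EulerUpdateNetwork(
#       encoder=FlowEncoder(flow=SymplecticLinearFlow(operation_input_dims=1)),
#       derivative_net=MLP(latent_sizes=[32, 2], activation=jax.nn.relu),
#       decoder=FlowDecoder(flow=SymplecticLinearFlow(operation_input_dims=1)))
#   q, p = jnp.ones((8, 1)), jnp.ones((8, 1))
#   params = model.init(jax.random.PRNGKey(0), q, p, 0.1)
#   predicted_q, predicted_p, _ = model.apply(params, q, p, 0.1)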