# Source: google-research (original listing: 286 lines, 10.3 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Construct different types of surrogate posteriors for VI."""
17import collections
18
19import numpy as np
20import tensorflow as tf
21import tensorflow_probability as tfp
22
23from tensorflow_probability.python.internal import prefer_static as ps
24
25tfb = tfp.bijectors
26tfd = tfp.distributions
27tfp_util = tfp.util
28
29LinearGaussianVariables = collections.namedtuple('LinearGaussianVariables',
30['matrix', 'loc', 'scale'])
31
32
def make_flow_posterior(prior,
                        num_hidden_units,
                        invert=True,
                        num_flow_layers=2):
  """Make a MAF/IAF surrogate posterior.

  A standard-normal base distribution over the flattened event space is
  transformed by `num_flow_layers` autoregressive flow layers (with a
  dimension-reversing permutation between layers), then split, reshaped, and
  constrained to match the structure and support of `prior`.

  Args:
    prior: tfd.JointDistribution instance of the prior.
    num_hidden_units: int value. Specifies the number of hidden units.
    invert: Optional Boolean value. If `True`, produces inverse autoregressive
      flow. If `False`, produces a masked autoregressive flow.
      Default value: `True`.
    num_flow_layers: Optional int value. Specifies the number of layers.
  Returns:
    surrogate_posterior: A `tfd.TransformedDistribution` instance
      whose samples have shape and structure matching that of `prior`.
  """

  event_shape = prior.event_shape_tensor()
  event_space_bijector = prior.experimental_default_event_space_bijector()
  flat_event_shape = tf.nest.flatten(event_shape)
  # Number of scalar elements in each flattened component of the event.
  flat_event_size = [
      tf.reduce_prod(s) for s in flat_event_shape]

  # Total dimensionality of the flattened event space.
  ndims = tf.reduce_sum(flat_event_size)
  dtype = tf.nest.flatten(prior.dtype)[0]

  # Reverses the order of dimensions so successive flow layers condition on
  # different orderings.
  make_swap = lambda: tfb.Permute(ps.range(ndims - 1, -1, -1))
  def make_maf():
    """Builds one masked (or inverse) autoregressive flow layer."""
    # Two output params per dimension: one for the shift, one for log-scale.
    net = tfb.AutoregressiveNetwork(
        2,
        hidden_units=[num_hidden_units, num_hidden_units],
        activation=tf.tanh,
        dtype=dtype)

    maf = tfb.MaskedAutoregressiveFlow(
        bijector_fn=lambda x: tfb.Chain([tfb.Shift(net(x)[Ellipsis, 0]),  # pylint: disable=g-long-lambda
                                         tfb.Scale(log_scale=net(x)[Ellipsis, 1])]))
    if invert:
      # Inverting a MAF yields an IAF (fast sampling, slow log-prob).
      maf = tfb.Invert(maf)
    # To track the variables
    maf._net = net  # pylint: disable=protected-access
    return maf

  # Base distribution: ndims-dimensional standard normal.
  dist = tfd.Sample(
      tfd.Normal(tf.zeros([], dtype=dtype), 1.), sample_shape=[ndims])

  # NOTE: `tfb.Chain` applies its bijectors last-to-first, so a base sample
  # flows through: MAF layer(s) -> Split -> Reshape -> Restructure -> the
  # prior's event-space (constraining) bijector.
  bijectors = [
      event_space_bijector,
      tfb.Restructure(
          tf.nest.pack_sequence_as(event_shape, range(len(flat_event_shape)))),
      tfb.JointMap(tf.nest.map_structure(tfb.Reshape, flat_event_shape)),
      tfb.Split(flat_event_size),
      ]
  bijectors.append(make_maf())

  for _ in range(num_flow_layers - 1):
    bijectors.extend([make_swap(), make_maf()])

  return tfd.TransformedDistribution(dist, tfb.Chain(bijectors))
93
94
def make_mvn_posterior(prior):
  """Build a Multivariate Normal (MVN) posterior.

  A single standard-normal vector over the flattened event space is passed
  through a trainable affine map (lower-triangular scale plus shift), then
  split, reshaped, and constrained to match the structure and support of
  `prior` — yielding a full-covariance Gaussian surrogate.

  Args:
    prior: tfd.JointDistribution instance of the prior.
  Returns:
    surrogate_posterior: A `tfd.TransformedDistribution` instance
      whose samples have shape and structure matching that of `prior`.
  """
  structured_event_shape = prior.event_shape_tensor()
  constraining_bijector = prior.experimental_default_event_space_bijector()
  component_shapes = tf.nest.flatten(structured_event_shape)
  component_sizes = [tf.reduce_prod(s) for s in component_shapes]

  total_size = tf.reduce_sum(component_sizes)
  dtype = tf.nest.flatten(prior.dtype)[0]

  base_dist = tfd.Sample(
      tfd.Normal(tf.zeros([], dtype), 1.), sample_shape=[total_size])
  scale_operator = make_trainable_linear_operator_tril(total_size)

  # `tfb.Chain` applies last-to-first: affine-transform the flat sample,
  # split per component, reshape each, restore the joint structure, then map
  # into the prior's support.
  chain = tfb.Chain([
      constraining_bijector,
      tfb.Restructure(
          tf.nest.pack_sequence_as(
              structured_event_shape, range(len(component_shapes)))),
      tfb.JointMap(tf.nest.map_structure(tfb.Reshape, component_shapes)),
      tfb.Split(component_sizes),
      tfb.Shift(tf.Variable(tf.zeros([total_size], dtype=dtype))),
      tfb.ScaleMatvecLinearOperator(scale_operator),
  ])
  return tfd.TransformedDistribution(base_dist, chain)
128
129
def make_trainable_linear_operator_tril(
    dim,
    scale_initializer=1e-1,
    diag_bijector=None,
    diag_shift=1e-5,
    dtype=tf.float32):
  """Build a trainable lower triangular linop.

  Args:
    dim: int dimension of the (square) operator.
    scale_initializer: float initial value placed on the diagonal.
    diag_bijector: optional bijector applied to the diagonal entries by
      `tfb.FillScaleTriL` (`None` uses that bijector's default).
    diag_shift: float shift added to the diagonal to keep it away from zero.
    dtype: dtype of the operator's entries.

  Returns:
    A `tf.linalg.LinearOperatorLowerTriangular` backed by a
    `tfp.util.TransformedVariable`, so the underlying variable is trained in
    unconstrained space.
  """
  # Initial value: a zero vector of the dim*(dim+1)/2 free tril entries,
  # pushed through FillScaleTriL with an identity diagonal bijector — i.e. a
  # tril matrix with `scale_initializer` on the diagonal, zeros elsewhere.
  num_free_entries = dim * (dim + 1) // 2
  initial_tril = tfb.FillScaleTriL(
      diag_bijector=tfb.Identity(), diag_shift=scale_initializer)(
          tf.zeros((num_free_entries,), dtype=dtype))
  # Constraint used during training: keeps the matrix lower triangular with a
  # (shifted) positive diagonal.
  constraint = tfb.FillScaleTriL(diag_bijector, diag_shift=diag_shift)
  trainable_tril = tfp_util.TransformedVariable(
      initial_tril, bijector=constraint)
  return tf.linalg.LinearOperatorLowerTriangular(tril=trainable_tril)
146
147
def build_autoregressive_surrogate_posterior(prior, make_conditional_dist_fn):
  """Build a chain-structured surrogate posterior.

  Walks the prior's model coroutine and, for each prior distribution, yields
  a trainable surrogate distribution conditioned on the value sampled for the
  previous (non-global) distribution, forming an autoregressive chain.

  Args:
    prior: JointDistribution instance.
    make_conditional_dist_fn: callable with signature `dist, variables =
      make_conditional_dist_fn(y_event_shape, x, x_event_ndims,
      variables=None)` that builds and returns a trainable distribution over
      unconstrained values, with the specified event shape, conditioned on an
      input `x`. If `variables` is not passed, the necessary variables should
      be created and returned. Passing the returned `variables` structure to
      future calls should replicate the same conditional distribution.
  Returns:
    surrogate_posterior: A `tfd.JointDistributionCoroutine` instance
      whose samples have shape and structure matching that of `prior`.
  """
  with tf.name_scope('build_autoregressive_surrogate_posterior'):

    Root = tfd.JointDistributionCoroutine.Root  # pylint: disable=invalid-name
    # One entry per distribution in the chain; populated on the first pass
    # through the generator and reused on every subsequent pass.
    trainable_variables = []

    def posterior_generator():
      """Coroutine mirroring the prior's model, yielding surrogate dists."""
      prior_gen = prior._model_coroutine()  # pylint: disable=protected-access

      previous_value = None
      previous_event_ndims = 0
      previous_dist_was_global = True

      dist = next(prior_gen)

      i = 0
      try:
        while True:
          # `Root`-wrapped distributions have no dependencies; unwrap to get
          # the underlying distribution.
          actual_dist = dist.distribution if isinstance(dist, Root) else dist
          event_shape = actual_dist.event_shape_tensor()

          # Keep global variables out of the chain.
          if previous_dist_was_global:
            # Condition on a dummy scalar so the surrogate is unconditional.
            previous_value = np.array(0., dtype=np.float32)
            previous_event_ndims = 0

          unconstrained_surrogate, dist_variables = make_conditional_dist_fn(
              y_event_shape=event_shape,
              x=previous_value,
              x_event_ndims=previous_event_ndims,
              variables=(trainable_variables[i]
                         if len(trainable_variables) > i else None))
          # If this is the first run, save the created variables to reuse later.
          if len(trainable_variables) <= i:
            trainable_variables.append(dist_variables)

          # Map the unconstrained surrogate into the support of the prior's
          # corresponding distribution.
          surrogate_dist = (
              actual_dist.experimental_default_event_space_bijector()(
                  unconstrained_surrogate))

          if previous_dist_was_global:
            value_out = yield Root(surrogate_dist)
          else:
            value_out = yield surrogate_dist

          # The sampled value conditions the next distribution in the chain.
          previous_value = value_out
          previous_event_ndims = ps.rank_from_shape(event_shape)
          previous_dist_was_global = isinstance(dist, Root)

          dist = prior_gen.send(value_out)
          i += 1
      except StopIteration:
        # Prior model exhausted; the surrogate chain ends here too.
        pass

    surrogate_posterior = tfd.JointDistributionCoroutine(posterior_generator)

    # Build variables.
    # (A single sample drives one full pass through the generator, which
    # creates and caches the trainable variables.)
    _ = surrogate_posterior.sample()

    # Attach the variables so they are found by `.trainable_variables`.
    surrogate_posterior._also_track = trainable_variables  # pylint: disable=protected-access
    return surrogate_posterior
224
225
def make_conditional_linear_gaussian(y_event_shape,
                                     x,
                                     x_event_ndims,
                                     variables=None):
  """Build trainable distribution `p(y | x)` conditioned on an input Tensor `x`.

  The distribution is independent Gaussian with mean linearly transformed
  from `x`:
  `y ~ N(loc=matvec(matrix, x) + loc, scale_diag=scale)`

  Args:
    y_event_shape: int `Tensor` event shape.
    x: `Tensor` input to condition on.
    x_event_ndims: int number of dimensions in `x`'s `event_shape`.
    variables: Optional `LinearGaussianVariables` instance, or `None`.
      Default value: `None`.

  Returns:
    dist: Instance of `tfd.Distribution` representing the conditional
      distribution `p(y | x)`.
    variables: Instance of `LinearGaussianVariables` used to parameterize
      `dist`. If a `variables` arg was passed, it is returned unmodified;
      otherwise new variables are created.
  """
  # Split x's shape into leading batch dims and trailing event dims.
  x_shape = ps.shape(x)
  split_point = ps.rank_from_shape(x_shape) - x_event_ndims
  batch_shape = x_shape[:split_point]
  x_event_shape = x_shape[split_point:]
  y_event_ndims = ps.rank_from_shape(y_event_shape)

  x_event_size = ps.reduce_prod(x_event_shape)
  y_event_size = ps.reduce_prod(y_event_shape)

  flattened_x_shape = ps.concat([batch_shape, [x_event_size]], axis=0)
  flattened_y_shape = ps.concat([batch_shape, [y_event_size]], axis=0)
  structured_y_shape = ps.concat([batch_shape, y_event_shape], axis=0)

  if variables is None:
    # First call: create fresh, randomly initialized parameters. The softplus
    # bijector keeps the scale positive during optimization.
    matrix_shape = ps.concat(
        [batch_shape, [y_event_size, x_event_size]], axis=0)
    variables = LinearGaussianVariables(
        matrix=tf.Variable(
            tf.random.normal(matrix_shape, dtype=x.dtype), name='matrix'),
        loc=tf.Variable(
            tf.random.normal(flattened_y_shape, dtype=x.dtype), name='loc'),
        scale=tfp_util.TransformedVariable(
            tf.ones(structured_y_shape, dtype=x.dtype),
            bijector=tfb.Softplus(),
            name='scale'))

  # Affine mean in flat space, then reshape back to y's structured shape.
  flat_mean = tf.linalg.matvec(
      variables.matrix, tf.reshape(x, flattened_x_shape)) + variables.loc
  dist = tfd.Normal(
      loc=tf.reshape(flat_mean, structured_y_shape), scale=variables.scale)
  if y_event_ndims != 0:
    # Fold the event dims of y into a single multivariate event.
    dist = tfd.Independent(dist, reinterpreted_batch_ndims=y_event_ndims)
  dist._also_track = variables  # pylint: disable=protected-access
  return dist, variables
287
288