# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Construct different types of surrogate posteriors for VI."""
import collections

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability.python.internal import prefer_static as ps

tfb = tfp.bijectors
tfd = tfp.distributions
tfp_util = tfp.util

LinearGaussianVariables = collections.namedtuple('LinearGaussianVariables',
                                                 ['matrix', 'loc', 'scale'])


def make_flow_posterior(prior,
                        num_hidden_units,
                        invert=True,
                        num_flow_layers=2):
  """Make a MAF/IAF surrogate posterior.

  Args:
    prior: tfd.JointDistribution instance of the prior.
    num_hidden_units: int value. Specifies the number of hidden units.
    invert: Optional Boolean value. If `True`, produces an inverse
      autoregressive flow. If `False`, produces a masked autoregressive flow.
      Default value: `True`.
    num_flow_layers: Optional int value. Specifies the number of layers.
      Default value: `2`.
  Returns:
    surrogate_posterior: A `tfd.TransformedDistribution` instance
      whose samples have shape and structure matching that of `prior`.
  """

  event_shape = prior.event_shape_tensor()
  event_space_bijector = prior.experimental_default_event_space_bijector()
  flat_event_shape = tf.nest.flatten(event_shape)
  flat_event_size = [
      tf.reduce_prod(s) for s in flat_event_shape]

  ndims = tf.reduce_sum(flat_event_size)
  dtype = tf.nest.flatten(prior.dtype)[0]

  make_swap = lambda: tfb.Permute(ps.range(ndims - 1, -1, -1))
  def make_maf():
    net = tfb.AutoregressiveNetwork(
        2,
        hidden_units=[num_hidden_units, num_hidden_units],
        activation=tf.tanh,
        dtype=dtype)

    maf = tfb.MaskedAutoregressiveFlow(
        bijector_fn=lambda x: tfb.Chain([tfb.Shift(net(x)[Ellipsis, 0]),  # pylint: disable=g-long-lambda
                                         tfb.Scale(log_scale=net(x)[Ellipsis, 1])]))
    if invert:
      maf = tfb.Invert(maf)
    # To track the variables
    maf._net = net  # pylint: disable=protected-access
    return maf

  dist = tfd.Sample(
      tfd.Normal(tf.zeros([], dtype=dtype), 1.), sample_shape=[ndims])

  bijectors = [
      event_space_bijector,
      tfb.Restructure(
          tf.nest.pack_sequence_as(event_shape, range(len(flat_event_shape)))),
      tfb.JointMap(tf.nest.map_structure(tfb.Reshape, flat_event_shape)),
      tfb.Split(flat_event_size),
      ]
  bijectors.append(make_maf())

  for _ in range(num_flow_layers - 1):
    bijectors.extend([make_swap(), make_maf()])

  return tfd.TransformedDistribution(dist, tfb.Chain(bijectors))
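

# Example usage (an illustrative sketch, not part of the original module):
# fit the flow surrogate with `tfp.vi.fit_surrogate_posterior`. The toy
# prior and all `_example_*` names below are hypothetical.
def _example_fit_flow_posterior():
  """Fits an IAF surrogate to a toy prior; returns surrogate and losses."""
  toy_prior = tfd.JointDistributionSequentialAutoBatched([
      tfd.Normal(0., 1., name='z'),
      lambda z: tfd.Normal(z, 1., name='x'),
  ])
  surrogate = make_flow_posterior(toy_prior, num_hidden_units=8)
  # With no observations, the ELBO target is just the prior's log density.
  losses = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=toy_prior.log_prob,
      surrogate_posterior=surrogate,
      optimizer=tf.optimizers.Adam(learning_rate=0.01),
      num_steps=200)
  return surrogate, losses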


def make_mvn_posterior(prior):
  """Build a multivariate normal (MVN) surrogate posterior.

  Args:
    prior: tfd.JointDistribution instance of the prior.
  Returns:
    surrogate_posterior: A `tfd.TransformedDistribution` instance
      whose samples have shape and structure matching that of `prior`.
  """

  event_shape = prior.event_shape_tensor()
  event_space_bijector = prior.experimental_default_event_space_bijector()
  flat_event_shape = tf.nest.flatten(event_shape)
  flat_event_size = [
      tf.reduce_prod(s) for s in flat_event_shape]

  ndims = tf.reduce_sum(flat_event_size)

  dtype = tf.nest.flatten(prior.dtype)[0]

  base_dist = tfd.Sample(
      tfd.Normal(tf.zeros([], dtype), 1.), sample_shape=[ndims])
  # Match the operator's dtype to the prior's to avoid mixed-dtype errors.
  op = make_trainable_linear_operator_tril(ndims, dtype=dtype)

  bijectors = [
      event_space_bijector,
      tfb.Restructure(
          tf.nest.pack_sequence_as(event_shape, range(len(flat_event_shape)))),
      tfb.JointMap(tf.nest.map_structure(tfb.Reshape, flat_event_shape)),
      tfb.Split(flat_event_size),
      tfb.Shift(tf.Variable(tf.zeros([ndims], dtype=dtype))),
      tfb.ScaleMatvecLinearOperator(op)]
  return tfd.TransformedDistribution(base_dist, tfb.Chain(bijectors))
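

# Example usage (an illustrative sketch, not part of the original module):
# the MVN surrogate's samples follow the structure and support of the prior;
# `toy_prior` is a hypothetical stand-in.
def _example_mvn_posterior():
  """Builds the MVN surrogate for a toy prior and draws structured samples."""
  toy_prior = tfd.JointDistributionSequentialAutoBatched([
      tfd.LogNormal(0., 1., name='scale'),
      lambda scale: tfd.Normal(0., scale, name='loc'),
  ])
  surrogate = make_mvn_posterior(toy_prior)
  # The event-space bijector constrains draws, so `scale` samples are
  # positive, matching the prior's support.
  return surrogate, surrogate.sample(3)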


def make_trainable_linear_operator_tril(
    dim,
    scale_initializer=1e-1,
    diag_bijector=None,
    diag_shift=1e-5,
    dtype=tf.float32):
  """Build a trainable lower-triangular LinearOperator.

  Args:
    dim: int dimension of the (square) operator.
    scale_initializer: float scalar; the operator is initialized to
      `scale_initializer` times the identity.
    diag_bijector: Optional bijector constraining the diagonal; if `None`,
      the `tfb.FillScaleTriL` default (softplus) is used.
    diag_shift: float shift added to the diagonal to bound it away from zero.
    dtype: `tf.DType` of the operator.
  Returns:
    op: A `tf.linalg.LinearOperatorLowerTriangular` with trainable entries.
  """
  scale_tril_bijector = tfb.FillScaleTriL(
      diag_bijector, diag_shift=diag_shift)
  flat_initial_scale = tf.zeros((dim * (dim + 1) // 2,), dtype=dtype)
  initial_scale_tril = tfb.FillScaleTriL(
      diag_bijector=tfb.Identity(), diag_shift=scale_initializer)(
          flat_initial_scale)
  return tf.linalg.LinearOperatorLowerTriangular(
      tril=tfp_util.TransformedVariable(
          initial_scale_tril, bijector=scale_tril_bijector))
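

# Example usage (an illustrative sketch, not part of the original module).
def _example_trainable_tril():
  """Shows the operator's initial value: `scale_initializer * identity`."""
  op = make_trainable_linear_operator_tril(dim=3)
  # `FillScaleTriL` keeps the matrix lower triangular with a positive
  # diagonal as the underlying variable is trained.
  return op.to_dense()  # ~= 0.1 * tf.eye(3) at initialization.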


def build_autoregressive_surrogate_posterior(prior, make_conditional_dist_fn):
  """Build a chain-structured surrogate posterior.

  Args:
    prior: JointDistribution instance.
    make_conditional_dist_fn: callable with signature `dist, variables =
      make_conditional_dist_fn(y_event_shape, x, x_event_ndims,
      variables=None)` that builds and returns a trainable distribution over
      unconstrained values, with the specified event shape, conditioned on an
      input `x`. If `variables` is not passed, the necessary variables should
      be created and returned. Passing the returned `variables` structure to
      future calls should replicate the same conditional distribution.
  Returns:
    surrogate_posterior: A `tfd.JointDistributionCoroutine` instance
      whose samples have shape and structure matching that of `prior`.
  """
  with tf.name_scope('build_autoregressive_surrogate_posterior'):

    Root = tfd.JointDistributionCoroutine.Root  # pylint: disable=invalid-name
    trainable_variables = []

    def posterior_generator():
      prior_gen = prior._model_coroutine()  # pylint: disable=protected-access

      previous_value = None
      previous_event_ndims = 0
      previous_dist_was_global = True

      dist = next(prior_gen)

      i = 0
      try:
        while True:
          actual_dist = dist.distribution if isinstance(dist, Root) else dist
          event_shape = actual_dist.event_shape_tensor()

          # Keep global (Root) variables out of the chain: condition on a
          # dummy scalar instead of the previous value.
          if previous_dist_was_global:
            previous_value = np.array(0., dtype=np.float32)
            previous_event_ndims = 0

          unconstrained_surrogate, dist_variables = make_conditional_dist_fn(
              y_event_shape=event_shape,
              x=previous_value,
              x_event_ndims=previous_event_ndims,
              variables=(trainable_variables[i]
                         if len(trainable_variables) > i else None))
          # On the first run, save the created variables to reuse later.
          if len(trainable_variables) <= i:
            trainable_variables.append(dist_variables)

          surrogate_dist = (
              actual_dist.experimental_default_event_space_bijector()(
                  unconstrained_surrogate))

          if previous_dist_was_global:
            value_out = yield Root(surrogate_dist)
          else:
            value_out = yield surrogate_dist

          previous_value = value_out
          previous_event_ndims = ps.rank_from_shape(event_shape)
          previous_dist_was_global = isinstance(dist, Root)

          dist = prior_gen.send(value_out)
          i += 1
      except StopIteration:
        pass

    surrogate_posterior = tfd.JointDistributionCoroutine(posterior_generator)

    # Sample once to build the variables.
    _ = surrogate_posterior.sample()

    surrogate_posterior._also_track = trainable_variables  # pylint: disable=protected-access
    return surrogate_posterior
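

# Example usage (an illustrative sketch, not part of the original module):
# chain the autoregressive surrogate with the conditional linear-Gaussian
# factory defined below. The three-variable coroutine prior is hypothetical.
def _example_autoregressive_posterior():
  """Builds and fits a chain-structured surrogate for a toy coroutine prior."""
  root = tfd.JointDistributionCoroutine.Root

  def toy_model():
    z = yield root(tfd.Normal(0., 1., name='z'))
    x = yield tfd.Normal(z, 1., name='x')
    yield tfd.Normal(x, 1., name='y')

  toy_prior = tfd.JointDistributionCoroutine(toy_model)
  # Surrogate factors condition on the previous non-global draw (here,
  # `y`'s factor conditions on `x`).
  surrogate = build_autoregressive_surrogate_posterior(
      toy_prior, make_conditional_dist_fn=make_conditional_linear_gaussian)
  losses = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=toy_prior.log_prob,
      surrogate_posterior=surrogate,
      optimizer=tf.optimizers.Adam(learning_rate=0.01),
      num_steps=200)
  return surrogate, losses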


def make_conditional_linear_gaussian(y_event_shape,
                                     x,
                                     x_event_ndims,
                                     variables=None):
  """Build a trainable distribution `p(y | x)` conditioned on a Tensor `x`.

  The distribution is an independent Gaussian whose mean is an affine
  function of `x`:
  `y ~ N(loc=matvec(matrix, x) + loc, scale_diag=scale)`

  Args:
    y_event_shape: int `Tensor` event shape.
    x: `Tensor` input to condition on.
    x_event_ndims: int number of dimensions in `x`'s `event_shape`.
    variables: Optional `LinearGaussianVariables` instance, or `None`.
      Default value: `None`.

  Returns:
    dist: Instance of `tfd.Distribution` representing the conditional
      distribution `p(y | x)`.
    variables: Instance of `LinearGaussianVariables` used to parameterize
      `dist`. If a `variables` arg was passed, it is returned unmodified;
      otherwise new variables are created.
  """
  x_shape = ps.shape(x)
  x_ndims = ps.rank_from_shape(x_shape)
  y_event_ndims = ps.rank_from_shape(y_event_shape)
  batch_shape, x_event_shape = (x_shape[:x_ndims - x_event_ndims],
                                x_shape[x_ndims - x_event_ndims:])

  x_event_size = ps.reduce_prod(x_event_shape)
  y_event_size = ps.reduce_prod(y_event_shape)

  x_flat_shape = ps.concat([batch_shape, [x_event_size]], axis=0)
  y_flat_shape = ps.concat([batch_shape, [y_event_size]], axis=0)
  y_full_shape = ps.concat([batch_shape, y_event_shape], axis=0)

  if variables is None:
    variables = LinearGaussianVariables(
        matrix=tf.Variable(
            tf.random.normal(
                ps.concat([batch_shape, [y_event_size, x_event_size]], axis=0),
                dtype=x.dtype),
            name='matrix'),
        loc=tf.Variable(
            tf.random.normal(y_flat_shape, dtype=x.dtype), name='loc'),
        scale=tfp_util.TransformedVariable(
            tf.ones(y_full_shape, dtype=x.dtype),
            bijector=tfb.Softplus(),
            name='scale'))

  flat_x = tf.reshape(x, x_flat_shape)
  dist = tfd.Normal(
      loc=tf.reshape(
          tf.linalg.matvec(variables.matrix, flat_x) + variables.loc,
          y_full_shape),
      scale=variables.scale)
  if y_event_ndims != 0:
    dist = tfd.Independent(dist, reinterpreted_batch_ndims=y_event_ndims)
  dist._also_track = variables  # pylint: disable=protected-access
  return dist, variables
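

# Example usage (an illustrative sketch, not part of the original module).
def _example_conditional_linear_gaussian():
  """Builds `p(y | x)` for a vector-valued `x` and reuses its variables."""
  x = tf.random.normal([4])  # A single length-4 event, with no batch dims.
  dist, variables = make_conditional_linear_gaussian(
      y_event_shape=[2, 3], x=x, x_event_ndims=1)
  # Passing `variables` back replicates the same conditional distribution.
  dist2, _ = make_conditional_linear_gaussian(
      y_event_shape=[2, 3], x=x, x_event_ndims=1, variables=variables)
  return dist.sample(), dist2  # Samples have event shape [2, 3].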