google-research

Форк
0
181 строка · 6.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Run surrogate posterior benchmarks."""
17
import json
18
import os
19
import pathlib
20
import time
21

22
from absl import app
23
from absl import flags
24
from inference_gym import using_tensorflow as inference_gym
25
import matplotlib.pyplot as plt
26
import tensorflow as tf
27
import tensorflow_probability as tfp
28

29

30
from automatic_structured_vi import make_surrogate_posteriors
31

32

33
# Alias for TensorFlow's GFile filesystem API; works uniformly for local
# paths and remote (e.g. GCS) paths.
gfile = tf.io.gfile

FLAGS = flags.FLAGS

# Which Inference Gym target model to benchmark; see `main` for the mapping
# from name to target class.
flags.DEFINE_enum('model_name', 'brownian_motion', [
    'brownian_motion', 'stochastic_volatility', 'radon', 'eight_schools',
    'lorenz_bridge', 'lorenz_bridge_global', 'brownian_motion_global'
], 'Inference Gym model')
# Which surrogate posterior family to fit (ASVI, normalizing flows,
# mean-field, multivariate normal, or autoregressive).
flags.DEFINE_enum('posterior_type', 'asvi', [
    'asvi', 'large_iaf', 'small_iaf', 'maf', 'mean_field', 'mvn',
    'autoregressive'
], 'Type of surrogate posterior to use.')
flags.DEFINE_integer('num_steps', 100000, 'Number of optimization steps')
flags.DEFINE_float('learning_rate', 1e-2, 'Optimizer learning rate')
# Only used by the 'asvi' posterior type (initial prior weight of the
# ASVI surrogate).
flags.DEFINE_float('prior_weight', 0.5, 'Initialization value of prior_weight.')
# Recorded verbatim in the output JSON to tag ensemble replicas.
flags.DEFINE_integer('ensemble_num', 0, 'Ensemble member ID.')
flags.DEFINE_string('output_dir',
                    os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'run_vi/'),
                    'Directory to store output files.')
52

53

54
def _make_model(model_name):
  """Returns the Inference Gym target model named by `model_name`.

  Args:
    model_name: String, one of the values accepted by the `model_name` flag.

  Returns:
    An Inference Gym target model instance.

  Raises:
    NotImplementedError: If `model_name` is not a recognized model name.
  """
  if model_name == 'brownian_motion':
    return inference_gym.targets.BrownianMotionMissingMiddleObservations()
  elif model_name == 'stochastic_volatility':
    return inference_gym.targets.StochasticVolatilitySP500Small()
  elif model_name == 'eight_schools':
    return inference_gym.targets.EightSchools()
  elif model_name == 'lorenz_bridge':
    return inference_gym.targets.ConvectionLorenzBridge()
  elif model_name == 'lorenz_bridge_global':
    return inference_gym.targets.ConvectionLorenzBridgeUnknownScales()
  elif model_name == 'brownian_motion_global':
    return inference_gym.targets.BrownianMotionUnknownScalesMissingMiddleObservations(
    )
  elif model_name == 'radon':
    return inference_gym.targets.RadonContextualEffectsHalfNormalMinnesota(
        dtype=tf.float32)
  else:
    raise NotImplementedError(
        '"{}" is not a valid value for `model_name`'.format(model_name))


def _make_surrogate_posterior(prior, posterior_type, prior_weight):
  """Builds the surrogate posterior distribution to be fit.

  Args:
    prior: The model's prior distribution (a TFP joint distribution).
    posterior_type: String, one of the values accepted by the
      `posterior_type` flag.
    prior_weight: Float initial prior weight; only used for 'asvi'.

  Returns:
    A trainable surrogate posterior distribution.

  Raises:
    NotImplementedError: If `posterior_type` is not recognized. (Previously
      an unknown value fell through silently and later raised `NameError`.)
  """
  if posterior_type == 'asvi':
    return tfp.experimental.vi.build_asvi_surrogate_posterior(
        prior, initial_prior_weight=prior_weight)
  elif posterior_type == 'mean_field':
    return tfp.experimental.vi.build_asvi_surrogate_posterior(
        prior, mean_field=True)
  elif posterior_type == 'large_iaf':
    return make_surrogate_posteriors.make_flow_posterior(
        prior, num_hidden_units=512, invert=True)
  elif posterior_type == 'small_iaf':
    return make_surrogate_posteriors.make_flow_posterior(
        prior, num_hidden_units=8, invert=True)
  elif posterior_type == 'maf':
    return make_surrogate_posteriors.make_flow_posterior(
        prior, num_hidden_units=512, invert=False)
  elif posterior_type == 'mvn':
    return make_surrogate_posteriors.make_mvn_posterior(prior)
  elif posterior_type == 'autoregressive':
    return make_surrogate_posteriors.build_autoregressive_surrogate_posterior(
        prior, make_surrogate_posteriors.make_conditional_linear_gaussian)
  else:
    raise NotImplementedError(
        '"{}" is not a valid value for `posterior_type`'.format(
            posterior_type))


def main(_):
  """Fits a surrogate posterior to an Inference Gym model and saves results.

  Writes under --output_dir: a loss-curve PNG, a results.json with losses,
  timings and the final ELBO estimate, and a samples.json with 100 posterior
  samples.
  """
  model_name = FLAGS.model_name
  num_steps = FLAGS.num_steps
  learning_rate = FLAGS.learning_rate
  posterior_type = FLAGS.posterior_type
  prior_weight = FLAGS.prior_weight
  # XManager experiment / work-unit ids only exist when launched via XM;
  # default both to -1 for standalone runs.
  xid = FLAGS.xm_xid if hasattr(FLAGS, 'xm_xid') else -1
  # BUG FIX: `wid` was referenced below without ever being defined, so any
  # run with xid > -1 raised NameError. Look it up the same way as `xid`.
  wid = FLAGS.xm_wid if hasattr(FLAGS, 'xm_wid') else -1

  output_dir = '{}_xid{}/{}'.format(FLAGS.output_dir, xid,
                                    wid) if xid > -1 else FLAGS.output_dir
  output_dir = pathlib.Path(output_dir)
  gfile.makedirs(output_dir)

  model = _make_model(model_name)

  prior = model.prior_distribution()
  # Unnormalized target log density = log likelihood + prior log prob.
  # Dict-valued priors receive named values; otherwise values arrive
  # positionally and are forwarded as a single tuple, matching what these
  # models' log_likelihood / log_prob expect.
  if isinstance(prior.event_shape, dict):
    target_log_prob = lambda **values: model.log_likelihood(  # pylint: disable=g-long-lambda
        values) + prior.log_prob(values)
  else:
    target_log_prob = lambda *values: model.log_likelihood(  # pylint: disable=g-long-lambda
        values) + prior.log_prob(values)

  opt = tf.optimizers.Adam(learning_rate)

  surrogate_dist = _make_surrogate_posterior(prior, posterior_type,
                                             prior_weight)

  @tf.function(experimental_compile=False)
  def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob,
        surrogate_dist,
        optimizer=opt,
        num_steps=num_steps)

  # First call includes tf.function tracing overhead.
  start = time.time()
  losses = fit_vi()
  trace_run_time = time.time() - start

  # Second call measures the steady-state (already-traced) run time.
  start = time.time()
  fit_vi()
  run_time = time.time() - start

  losses = losses.numpy()
  posterior_samples = surrogate_dist.sample(100)

  samples = surrogate_dist.sample(1000)

  # Monte Carlo ELBO estimate from 1000 fresh posterior samples.
  if isinstance(prior.event_shape, dict):
    target_lp = target_log_prob(**samples)
  else:
    target_lp = target_log_prob(*samples)
  final_elbo = tf.reduce_mean(
      target_lp - surrogate_dist.log_prob(samples)).numpy().tolist()

  json_output = {
      'losses': losses.tolist(),
      # Isolate the tracing overhead from the steady-state run time.
      'trace_time': trace_run_time - run_time,
      'run_time': run_time,
      'num_steps': num_steps,
      'final_elbo': final_elbo,
      'learning_rate': learning_rate,
      'ensemble_num': FLAGS.ensemble_num,
      'xm_xid': str(xid)
  }

  # Save the loss curve for quick visual inspection.
  fig, ax = plt.subplots()
  ax.plot(losses)
  ax.set_xlabel('Iterations')
  ax.set_ylabel('Loss')
  with tf.io.gfile.GFile(
      os.path.join(output_dir, 'loss_plot.png'), 'w') as fp:
    fig.savefig(fp)

  with tf.io.gfile.GFile(
      os.path.join(output_dir, 'results.json'), 'w') as out_file:
    json.dump(json_output, out_file)

  with tf.io.gfile.GFile(
      os.path.join(output_dir, 'samples.json'), 'w') as out_file:
    json.dump(tf.nest.map_structure(
        lambda x: x.numpy().tolist(), posterior_samples), out_file)
179

180
if __name__ == '__main__':
  # absl's app.run parses flags and invokes main with remaining argv.
  app.run(main)
182

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.