# Extracted from the google-research repository (listing header: 181 lines, 6.2 KB).
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run surrogate posterior benchmarks."""
import json
import os
import pathlib
import time

from absl import app
from absl import flags
from inference_gym import using_tensorflow as inference_gym
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

from automatic_structured_vi import make_surrogate_posteriors
# Short alias for the `tf.io.gfile` module.
gfile = tf.io.gfile

# Global absl flag registry; values are read inside `main`.
FLAGS = flags.FLAGS

# Which Inference Gym target to run the benchmark on.
flags.DEFINE_enum(
    name='model_name',
    default='brownian_motion',
    enum_values=[
        'brownian_motion', 'stochastic_volatility', 'radon', 'eight_schools',
        'lorenz_bridge', 'lorenz_bridge_global', 'brownian_motion_global'
    ],
    help='Inference Gym model')

# Which family of surrogate posterior to fit.
flags.DEFINE_enum(
    name='posterior_type',
    default='asvi',
    enum_values=[
        'asvi', 'large_iaf', 'small_iaf', 'maf', 'mean_field', 'mvn',
        'autoregressive'
    ],
    help='Type of surrogate posterior to use.')

# Optimization settings.
flags.DEFINE_integer(
    name='num_steps', default=100000, help='Number of optimization steps')
flags.DEFINE_float(
    name='learning_rate', default=1e-2, help='Optimizer learning rate')
flags.DEFINE_float(
    name='prior_weight',
    default=0.5,
    help='Initialization value of prior_weight.')

# Bookkeeping: the ensemble member ID is copied into the results file.
flags.DEFINE_integer(name='ensemble_num', default=0, help='Ensemble member ID.')

# Output location; defaults to `run_vi/` under $TEST_TMPDIR, or /tmp if unset.
flags.DEFINE_string(
    name='output_dir',
    default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 'run_vi/'),
    help='Directory to store output files.')
def _build_target_model(model_name):
  """Instantiates the Inference Gym target selected by `model_name`.

  Args:
    model_name: Name of the target; one of the values accepted by the
      `model_name` flag.

  Returns:
    The constructed `inference_gym.targets` model.

  Raises:
    NotImplementedError: If `model_name` matches none of the known targets.
  """
  targets = inference_gym.targets
  if model_name == 'brownian_motion':
    return targets.BrownianMotionMissingMiddleObservations()
  if model_name == 'stochastic_volatility':
    return targets.StochasticVolatilitySP500Small()
  if model_name == 'eight_schools':
    return targets.EightSchools()
  if model_name == 'lorenz_bridge':
    return targets.ConvectionLorenzBridge()
  if model_name == 'lorenz_bridge_global':
    return targets.ConvectionLorenzBridgeUnknownScales()
  if model_name == 'brownian_motion_global':
    return targets.BrownianMotionUnknownScalesMissingMiddleObservations()
  if model_name == 'radon':
    return targets.RadonContextualEffectsHalfNormalMinnesota(dtype=tf.float32)
  raise NotImplementedError(
      '"{}" is not a valid value for `model_name`'.format(model_name))


def _build_surrogate_posterior(prior, posterior_type, prior_weight):
  """Builds the surrogate posterior family selected by `posterior_type`.

  Args:
    prior: Prior distribution of the target model.
    posterior_type: Name of the surrogate family; one of the values accepted
      by the `posterior_type` flag.
    prior_weight: Initial prior weight; only used by the 'asvi' family.

  Returns:
    The surrogate posterior distribution to be optimized.

  Raises:
    NotImplementedError: If `posterior_type` matches none of the known
      families.
  """
  if posterior_type == 'asvi':
    return tfp.experimental.vi.build_asvi_surrogate_posterior(
        prior, initial_prior_weight=prior_weight)
  if posterior_type == 'mean_field':
    return tfp.experimental.vi.build_asvi_surrogate_posterior(
        prior, mean_field=True)
  if posterior_type == 'large_iaf':
    return make_surrogate_posteriors.make_flow_posterior(
        prior, num_hidden_units=512, invert=True)
  if posterior_type == 'small_iaf':
    return make_surrogate_posteriors.make_flow_posterior(
        prior, num_hidden_units=8, invert=True)
  if posterior_type == 'maf':
    return make_surrogate_posteriors.make_flow_posterior(
        prior, num_hidden_units=512, invert=False)
  if posterior_type == 'mvn':
    return make_surrogate_posteriors.make_mvn_posterior(prior)
  if posterior_type == 'autoregressive':
    return make_surrogate_posteriors.build_autoregressive_surrogate_posterior(
        prior, make_surrogate_posteriors.make_conditional_linear_gaussian)
  # Mirrors the `model_name` handling so an unknown value fails loudly here
  # instead of surfacing later as an undefined variable.
  raise NotImplementedError(
      '"{}" is not a valid value for `posterior_type`'.format(posterior_type))


def main(_):
  """Fits the selected surrogate posterior and writes the benchmark results.

  Configuration is read from the module-level flags. The fit is executed
  twice: the first call includes `tf.function` tracing, the second call is
  timed as the actual run time. Three files are written to the output
  directory: `loss_plot.png`, `results.json` and `samples.json`.

  Args:
    _: Unused positional command-line arguments passed by `app.run`.

  Raises:
    NotImplementedError: If `model_name` or `posterior_type` is not one of
      the supported values.
  """
  model_name = FLAGS.model_name
  num_steps = FLAGS.num_steps
  learning_rate = FLAGS.learning_rate
  posterior_type = FLAGS.posterior_type
  prior_weight = FLAGS.prior_weight

  # These flags are not defined in this file; -1 means "flag not available".
  xid = getattr(FLAGS, 'xm_xid', -1)
  # NOTE(review): `wid` was used in the path below without ever being
  # assigned, so any run with `xm_xid` set died with a NameError. The flag
  # name `xm_wid` is assumed by analogy with `xm_xid` -- confirm.
  wid = getattr(FLAGS, 'xm_wid', -1)

  if xid > -1:
    output_dir = '{}_xid{}/{}'.format(FLAGS.output_dir, xid, wid)
  else:
    output_dir = FLAGS.output_dir
  # Kept as a plain string on purpose: `pathlib.Path` collapses repeated
  # slashes, which would corrupt scheme-prefixed paths such as `gs://...`.
  gfile.makedirs(output_dir)

  model = _build_target_model(model_name)
  prior = model.prior_distribution()
  has_dict_event = isinstance(prior.event_shape, dict)

  def joint_log_prob(values):
    """Unnormalized target density: log-likelihood plus prior log-prob."""
    return model.log_likelihood(values) + prior.log_prob(values)

  # The target is called with the sample structure unpacked, so its signature
  # has to follow the structure of the prior (keywords for dicts, positional
  # arguments otherwise).
  if has_dict_event:
    def target_log_prob(**values):
      return joint_log_prob(values)
  else:
    def target_log_prob(*values):
      return joint_log_prob(values)

  opt = tf.optimizers.Adam(learning_rate)
  surrogate_dist = _build_surrogate_posterior(
      prior, posterior_type, prior_weight)

  @tf.function(experimental_compile=False)
  def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob,
        surrogate_dist,
        optimizer=opt,
        num_steps=num_steps)

  # First call: includes the one-off tracing cost of the tf.function.
  start = time.time()
  losses = fit_vi()
  trace_run_time = time.time() - start

  # Second call: actual run time of the already-traced function.
  start = time.time()
  fit_vi()
  run_time = time.time() - start

  losses = losses.numpy()
  posterior_samples = surrogate_dist.sample(100)

  # Monte Carlo estimate of the final ELBO from a fresh batch of samples.
  samples = surrogate_dist.sample(1000)
  if has_dict_event:
    target_values = target_log_prob(**samples)
  else:
    target_values = target_log_prob(*samples)
  final_elbo = tf.reduce_mean(
      target_values - surrogate_dist.log_prob(samples)).numpy().tolist()

  json_output = {
      'losses': losses.tolist(),
      # Tracing overhead: first (traced) call minus the steady-state call.
      'trace_time': trace_run_time - run_time,
      'run_time': run_time,
      'num_steps': num_steps,
      'final_elbo': final_elbo,
      'learning_rate': learning_rate,
      'ensemble_num': FLAGS.ensemble_num,
      'xm_xid': str(xid)
  }

  fig, ax = plt.subplots()
  ax.plot(losses)
  ax.set_xlabel('Iterations')
  ax.set_ylabel('Loss')
  # PNG data is binary, so open the file in binary mode and pin the format
  # rather than relying on matplotlib's configured default.
  with gfile.GFile(os.path.join(output_dir, 'loss_plot.png'), 'wb') as fp:
    fig.savefig(fp, format='png')
  plt.close(fig)  # Release the figure once it has been written.

  with gfile.GFile(os.path.join(output_dir, 'results.json'), 'w') as out_file:
    json.dump(json_output, out_file)

  with gfile.GFile(os.path.join(output_dir, 'samples.json'), 'w') as out_file:
    json.dump(
        tf.nest.map_structure(lambda x: x.numpy().tolist(), posterior_samples),
        out_file)
# Script entry point: run `main` through absl's application runner.
if __name__ == '__main__':
  app.run(main)