google-research
378 строк · 12.9 Кб
1{
2"cells": [
3{
4"cell_type": "markdown",
5"metadata": {
6"id": "Mf3kOv1YMB5y"
7},
8"source": [
9"Copyright 2021 Google LLC."
10]
11},
12{
13"cell_type": "code",
14"execution_count": null,
15"metadata": {
16"id": "-rOdskBSMfQN"
17},
18"outputs": [],
19"source": [
20"#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
21"# you may not use this file except in compliance with the License.\n",
22"# You may obtain a copy of the License at\n",
23"#\n",
24"# https://www.apache.org/licenses/LICENSE-2.0\n",
25"#\n",
26"# Unless required by applicable law or agreed to in writing, software\n",
27"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
28"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
29"# See the License for the specific language governing permissions and\n",
30"# limitations under the License."
31]
32},
33{
34"cell_type": "code",
35"execution_count": null,
36"metadata": {
37"id": "wB8E6bWergh9"
38},
39"outputs": [],
40"source": [
41"import jax\n",
42"import jax.numpy as jnp\n",
43"import matplotlib.pyplot as plt\n",
44"\n",
45"from colabtools import adhoc_import\n",
46"from contextlib import ExitStack\n",
47"\n",
48"ADHOC = True\n",
49"CLIENT = 'fig-export-fig_tree-change-451-3e0a679e9746'\n",
50"\n",
51"import tensorflow_probability.substrates.jax as tfp\n",
52"from fun_mc import using_jax as fun_mc\n",
53"\n",
54"tfd = tfp.distributions"
55]
56},
57{
58"cell_type": "markdown",
59"metadata": {
60"id": "qwbxRbDmFzOQ"
61},
62"source": [
63"# Variance analysis"
64]
65},
66{
67"cell_type": "code",
68"execution_count": null,
69"metadata": {
70"id": "SHi-hjl2vtlb"
71},
72"outputs": [],
73"source": [
n_chains = 10240
n_super_chains = 4
n_steps = 100
n_sub_chains = n_chains // n_super_chains
n_seeds = 200

# Which nested-ESS formula to use. False: overall variance divided by the
# variance of the super-chain means. True: the alternative calculation from
# Charles's notebook. (Replaces a hard-coded `if True:` that made the
# alternative branch unreachable.)
USE_CHARLES_NESTED_ESS = False


@jax.jit
def _ess_estimates(seed):
  """ESS estimates for one batch of i.i.d. N(0, 1) chains.

  Draws `n_chains` chains of `n_steps` steps and groups them into
  `n_super_chains` super chains of `n_sub_chains` sub chains each.

  Args:
    seed: JAX PRNG key.

  Returns:
    (ess, pooled_ess, nested_ess) scalars:
      ess: ESS computed by looking at the raw chains.
      pooled_ess: ESS computed by looking at the super chains (per-step
        averages over their sub chains).
      nested_ess: like pooled ESS, but the overall variance uses every draw,
        so it accounts for the number of sub chains used.
  """
  chain = jax.random.normal(seed, [n_steps, n_chains])
  # [n_steps, n_super_chains, n_sub_chains]: the sub chains of one super chain
  # are contiguous chain indices, the same layout the MCMC test below builds
  # with `jnp.repeat(...).reshape([n_super_chains, n_sub_chains])`.
  pooled_chain = chain.reshape([n_steps, n_super_chains, n_sub_chains])

  between = chain.mean(0).var(0, ddof=1)
  overall = chain.var((0, 1), ddof=1)
  ess = overall / between

  super_chain = pooled_chain.mean(2)  # [n_steps, n_super_chains]
  pooled_between = super_chain.mean(0).var(0, ddof=1)
  pooled_overall = super_chain.var((0, 1), ddof=1)
  pooled_ess = pooled_overall / pooled_between

  if USE_CHARLES_NESTED_ESS:
    # Calculation from Charles's notebook.
    mean_chain = pooled_chain.mean(0)  # [n_super_chains, n_sub_chains]
    mean_super_chain = pooled_chain.mean((0, 2))  # [n_super_chains]
    variance_chain = pooled_chain.var(0, ddof=1)  # [n_super_chains, n_sub_chains]
    variance_nested_chain = mean_chain.var(1, ddof=1) + variance_chain.mean(1)
    within_var = variance_nested_chain.mean(0)
    between_var = mean_super_chain.var(0, ddof=1)
    nested_ess = 1 + within_var / between_var
  else:
    nested_between = pooled_chain.mean((0, 2)).var(0, ddof=1)
    nested_overall = pooled_chain.var((0, 1, 2), ddof=1)
    nested_ess = nested_overall / nested_between

  return ess, pooled_ess, nested_ess


seeds = jax.random.split(jax.random.PRNGKey(0), n_seeds)
ess_vals, pooled_ess_vals, nested_ess_vals = map(
    jnp.array, zip(*(_ess_estimates(seed) for seed in seeds)))

# We can also normalize the nested ESS values to take into account that super
# chains are larger than regular chains. This interprets the pooling as an
# ostensibly denoised ESS estimator.
nested_normalized_ess_vals = nested_ess_vals / n_sub_chains

# These are actually rhat - 1
rhat_vals = 1 / ess_vals
pooled_rhat_vals = 1 / pooled_ess_vals
nested_rhat_vals = 1 / nested_ess_vals
# It's not clear what the meaning of this is.
nested_normalized_rhat_vals = 1 / nested_normalized_ess_vals

ess_panels = [
    ('ESS', ess_vals),
    ('pooled ESS', pooled_ess_vals),
    ('nested ESS', nested_ess_vals),
    ('nested normalized ESS', nested_normalized_ess_vals),
]
rhat_panels = [
    ('rhat - 1', rhat_vals),
    ('pooled rhat - 1', pooled_rhat_vals),
    ('nested rhat - 1', nested_rhat_vals),
    ('nested normalized rhat - 1', nested_normalized_rhat_vals),
]

# For i.i.d. draws, ESS, pooled ESS and nested normalized ESS should all be
# close to n_steps, while nested ESS should be close to n_steps * n_sub_chains.
for name, vals in ess_panels:
  print(f'{name} mean + std:', vals.mean(), vals.std())
print()
for name, vals in rhat_panels:
  print(f'{name} mean + std:', vals.mean(), vals.std())

# Top row: ESS variants; bottom row: the matching rhat - 1 variants.
fig, axes = plt.subplots(2, 4, figsize=(24, 6))
for ax, (name, vals) in zip(axes.flat, ess_panels + rhat_panels):
  ax.set_title(f'log10 {name}')
  ax.hist(jnp.log10(vals), histtype='step', density=True, bins=50)
  ax.set_ylabel('density')
fig.tight_layout()
175]
176},
177{
178"cell_type": "markdown",
179"metadata": {
180"id": "7CzwZYfZF1jM"
181},
182"source": [
183"# MCMC test"
184]
185},
186{
187"cell_type": "code",
188"execution_count": null,
189"metadata": {
190"id": "1j78wa-TVW3K"
191},
192"outputs": [],
193"source": [
dist = tfd.Normal(0., 1.)
n_chains = 10240
n_super_chains = 8
# Number of HMC steps per run (10000 is the length that was previously
# hard-coded in both `fun_mc.trace` calls).
n_steps = 10000
n_sub_chains = n_chains // n_super_chains

# Never reuse a PRNG key: each run gets one key for its initial states and a
# different one for its HMC kernel.
init_key, run_key = jax.random.split(jax.random.PRNGKey(0))
init_key2, run_key2 = jax.random.split(jax.random.PRNGKey(3))


def target_log_prob_fn(x):
  """Returns the log-density of `dist` at `x` plus an empty extra output."""
  return dist.log_prob(x), ()


def kernel(hmc_state, seed):
  """Takes one HMC step and threads the PRNG key through the traced state.

  Args:
    hmc_state: State from `fun_mc.hamiltonian_monte_carlo_init` or a previous
      step.
    seed: JAX PRNG key; split so the next step receives a fresh key.

  Returns:
    ((new_hmc_state, next_seed), (new_hmc_state.state, is_accepted)).
  """
  step_key, seed = jax.random.split(seed)
  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(
      hmc_state,
      target_log_prob_fn=target_log_prob_fn,
      step_size=0.5,
      num_integrator_steps=1,
      seed=step_key)
  return (hmc_state, seed), (hmc_state.state, hmc_extra.is_accepted)


def _run_chains(init_x, seed):
  """Runs `n_steps` HMC steps starting from `init_x`.

  Returns:
    (chain, is_accepted): the traced states with `init_x` prepended along
    axis 0, and the traced acceptance flags (no initial entry).
  """
  hmc_state = fun_mc.hamiltonian_monte_carlo_init(init_x, target_log_prob_fn)
  _, (chain, is_accepted) = fun_mc.trace((hmc_state, seed), kernel, n_steps)
  return jnp.concatenate([init_x[jnp.newaxis], chain], 0), is_accepted


# Regular run: every chain starts from its own draw.
init_x = dist.sample([n_chains], seed=init_key)
chain, is_accepted = _run_chains(init_x, run_key)

# Nested run: the n_sub_chains sub chains of each super chain share one initial
# draw; shape [n_super_chains, n_sub_chains]. (To start every sub chain
# independently instead, sample [n_super_chains, n_sub_chains] values directly.)
init_x2 = jnp.repeat(
    dist.sample([n_super_chains], seed=init_key2), n_sub_chains
).reshape([n_super_chains, n_sub_chains])
chain2, is_accepted2 = _run_chains(init_x2, run_key2)
231]
232},
233{
234"cell_type": "code",
235"execution_count": null,
236"metadata": {
237"id": "8udM_46EsKaC"
238},
239"outputs": [],
240"source": [
# Trace plot of the first 4 independently initialized chains.
plt.plot(chain[:, :4])
plt.title('regular chains (first 4)')
plt.xlabel('step')
plt.ylabel('state');
242]
243},
244{
245"cell_type": "code",
246"execution_count": null,
247"metadata": {
248"id": "QoMhF_n0tSQ2"
249},
250"outputs": [],
251"source": [
# Trace plot of the first 4 sub chains of super chain 0 (they share one
# initial state).
plt.plot(chain2[:, 0, :4])
plt.title('super chain 0: first 4 sub chains')
plt.xlabel('step')
plt.ylabel('state');
253]
254},
255{
256"cell_type": "code",
257"execution_count": null,
258"metadata": {
259"id": "krTTXX9lziZV"
260},
261"outputs": [],
262"source": [
# Spread of the initial states: variance across the super-chain means (each
# super chain's sub chains share one start) vs. across the regular chains.
(chain2[0].mean(axis=-1).var(axis=0),
 chain[0].var(axis=0))
264]
265},
266{
267"cell_type": "code",
268"execution_count": null,
269"metadata": {
270"id": "I1oq9lixtXC1"
271},
272"outputs": [],
273"source": [
def _running_mean(x):
  """Cumulative mean along axis 0: entry t is the mean of x[:t + 1]."""
  counts = jnp.arange(1, x.shape[0] + 1).reshape((-1,) + (1,) * (x.ndim - 1))
  return jnp.cumsum(x, 0) / counts


# Between-chain variance of the running chain means as a function of chain
# length. ddof=1 (as in the variance analysis) keeps the estimate unbiased,
# which matters for the super chains because there are only n_super_chains of
# them. For the nested run's individual sub chains use
# `_running_mean(chain2).var((1, 2), ddof=1)`.
between_reg = _running_mean(chain).var(1, ddof=1)
super_chain = chain2.mean(-1)
between_nested = _running_mean(super_chain).var(1, ddof=1)
278]
279},
280{
281"cell_type": "code",
282"execution_count": null,
283"metadata": {
284"id": "irascTzltwiv"
285},
286"outputs": [],
287"source": [
plt.title('between chain variance')
plt.plot(between_reg, label='regular chain')
plt.plot(between_nested, label='super chain')
# If the sub chains of a super chain were independent, the variance of a
# super-chain mean would be the regular-chain value divided by n_sub_chains.
plt.plot(between_reg / n_sub_chains, label='regular chain / n_sub_chains')
# NOTE(review): the meaning of the 1e-2 reference line is not stated here;
# presumably a target between-chain variance -- confirm.
plt.axhline(1e-2, ls='--', color='black', lw=2)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('chain length')
plt.ylabel('variance of running chain means')
plt.legend();
297]
298},
299{
300"cell_type": "code",
301"execution_count": null,
302"metadata": {
303"id": "C5LEwbtr_UHG"
304},
305"outputs": [],
306"source": [
# Keep the current run's curves under "2" names so a later run can be overlaid
# on them in the comparison plot.
# NOTE(review): this only differs from the current run if earlier cells are
# manually re-run with other settings after this cell; under Restart & Run All
# both runs are identical.
between_reg2, between_nested2, n_sub_chains2 = (
    between_reg, between_nested, n_sub_chains)
310]
311},
312{
313"cell_type": "code",
314"execution_count": null,
315"metadata": {
316"id": "Y2Fp5rLNTxkZ"
317},
318"outputs": [],
319"source": [
320""
321]
322},
323{
324"cell_type": "code",
325"execution_count": null,
326"metadata": {
327"id": "cxG6z3UuSwW4"
328},
329"outputs": [],
330"source": [
# Overlay the current run with the snapshot stored under "2" names. Each run's
# number of sub chains is shown in the legend rather than printed.
runs = [
    ('', between_reg, between_nested, n_sub_chains),
    (' 2', between_reg2, between_nested2, n_sub_chains2),
]

plt.figure(figsize=(12, 8))
plt.title('between chain variance')
for suffix, reg, nested, n_sub in runs:
  plt.plot(reg, label=f'regular chain{suffix}')
  plt.plot(nested, label=f'super chain{suffix}')
  plt.plot(reg / n_sub,
           label=f'regular chain{suffix} / n_sub_chains{suffix} (= {n_sub})')

# NOTE(review): the meaning of the 1e-2 reference line is not stated here;
# presumably a target between-chain variance -- confirm.
plt.axhline(1e-2, ls='--', color='black', lw=2)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('chain length')
plt.ylabel('variance of running chain means')
plt.legend();
349]
350}
351],
352"metadata": {
353"colab": {
354"collapsed_sections": [],
355"last_runtime": {
356"build_target": "//learning/deepmind/dm_python:dm_notebook3",
357"kind": "private"
358},
359"name": "Nested R-hat.ipynb",
360"private_outputs": true,
361"provenance": [
362{
363"file_id": "199pnTb5NJtFaLozNAlp3todlh4yjHD6g",
364"timestamp": 1632344135035
365}
366]
367},
368"kernelspec": {
369"display_name": "Python 3",
370"name": "python3"
371},
372"language_info": {
373"name": "python"
374}
375},
376"nbformat": 4,
377"nbformat_minor": 0
378}
379