google-research

Форк
0
/
Nested_R_hat.ipynb 
378 строк · 12.9 Кб
1
{
2
  "cells": [
3
    {
4
      "cell_type": "markdown",
5
      "metadata": {
6
        "id": "Mf3kOv1YMB5y"
7
      },
8
      "source": [
9
        "Copyright 2021 Google LLC."
10
      ]
11
    },
12
    {
13
      "cell_type": "code",
14
      "execution_count": null,
15
      "metadata": {
16
        "id": "-rOdskBSMfQN"
17
      },
18
      "outputs": [],
19
      "source": [
20
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
21
        "# you may not use this file except in compliance with the License.\n",
22
        "# You may obtain a copy of the License at\n",
23
        "#\n",
24
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
25
        "#\n",
26
        "# Unless required by applicable law or agreed to in writing, software\n",
27
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
28
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
29
        "# See the License for the specific language governing permissions and\n",
30
        "# limitations under the License."
31
      ]
32
    },
33
    {
34
      "cell_type": "code",
35
      "execution_count": null,
36
      "metadata": {
37
        "id": "wB8E6bWergh9"
38
      },
39
      "outputs": [],
40
      "source": [
41
        "import jax\n",
42
        "import jax.numpy as jnp\n",
43
        "import matplotlib.pyplot as plt\n",
44
        "\n",
45
        "from colabtools import adhoc_import\n",
46
        "from contextlib import ExitStack\n",
47
        "\n",
48
        "ADHOC = True\n",
49
        "CLIENT = 'fig-export-fig_tree-change-451-3e0a679e9746'\n",
50
        "\n",
51
        "import tensorflow_probability.substrates.jax as tfp\n",
52
        "from fun_mc import using_jax as fun_mc\n",
53
        "\n",
54
        "tfd = tfp.distributions"
55
      ]
56
    },
57
    {
58
      "cell_type": "markdown",
59
      "metadata": {
60
        "id": "qwbxRbDmFzOQ"
61
      },
62
      "source": [
63
        "# Variance analysis: ESS and R-hat estimators (raw, pooled, nested) on i.i.d. Gaussian chains"
64
      ]
65
    },
66
    {
67
      "cell_type": "code",
68
      "execution_count": null,
69
      "metadata": {
70
        "id": "SHi-hjl2vtlb"
71
      },
72
      "outputs": [],
73
      "source": [
74
        "n_chains = 10240\n",
75
        "n_super_chains = 4\n",
76
        "n_steps = 100\n",
77
        "n_sub_chains = n_chains // n_super_chains\n",
78
        "\n",
79
        "# ESS by looking at raw chains\n",
80
        "ess_vals = []\n",
81
        "# ESS computed by looking at super chains\n",
82
        "pooled_ess_vals = []\n",
83
        "# Like pooled ESS, but we account for the number of sub chains used\n",
84
        "nested_ess_vals = []\n",
85
        "\n",
86
        "for seed in jax.random.split(jax.random.PRNGKey(0), 200):\n",
87
        "  chain = jax.random.normal(seed, [n_steps, n_chains])\n",
88
        "  pooled_chain = chain.reshape([n_steps, n_sub_chains, n_super_chains])\n",
89
        "\n",
90
        "  between = chain.mean(0).var(0, ddof=1)\n",
91
        "  overall = chain.var((0, 1), ddof=1)\n",
92
        "  ess_vals.append(overall / between)\n",
93
        "\n",
94
        "  super_chain = pooled_chain.mean(1)\n",
95
        "  pooled_between = super_chain.mean(0).var(0, ddof=1)\n",
96
        "  pooled_overall = super_chain.var((0, 1), ddof=1)\n",
97
        "  pooled_ess_vals.append(pooled_overall / pooled_between)\n",
98
        "\n",
99
        "  if True:\n",
100
        "    nested_between = pooled_chain.mean((0, 1)).var(0, ddof=1)\n",
101
        "    nested_overall = pooled_chain.var((0, 1, 2), ddof=1)\n",
102
        "    nested_ess_vals.append((nested_overall / nested_between))\n",
103
        "  else:\n",
104
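        "    # Alternative calculation from Charles's notebook (disabled by `if True` above): nested ESS = 1 + W / B, where W is the mean within-super-chain variance and B is the variance of the super-chain means.\n",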
105
        "    mean_chain = pooled_chain.mean(0)\n",
106
        "    mean_super_chain = pooled_chain.mean((0, 1))\n",
107
        "    variance_chain = pooled_chain.var(0, ddof=1)\n",
108
        "    variance_nested_chain = mean_chain.var(0, ddof=1) + variance_chain.mean(0)\n",
109
        "\n",
110
        "    within_var = variance_nested_chain.mean(0)\n",
111
        "    between_var = mean_super_chain.var(0, ddof=1)\n",
112
        "\n",
113
        "    nested_ess_vals.append((1 + within_var / between_var))\n",
114
        "\n",
115
        "ess_vals = jnp.array(ess_vals)\n",
116
        "pooled_ess_vals = jnp.array(pooled_ess_vals)\n",
117
        "nested_ess_vals = jnp.array(nested_ess_vals)\n",
118
        "# We can also normalize the nested ESS values to take into account that super\n",
119
        "# chains are larger than regular chains. This interprets the pooling as an\n",
120
        "# ostensibly denoised ESS estimator.\n",
121
        "nested_normalized_ess_vals = nested_ess_vals / n_sub_chains\n",
122
        "\n",
123
        "# These are actually rhat - 1\n",
124
        "rhat_vals = 1 / ess_vals\n",
125
        "pooled_rhat_vals = 1 / pooled_ess_vals\n",
126
        "nested_rhat_vals = 1 / nested_ess_vals\n",
127
        "# NOTE: the interpretation of this quantity (n_sub_chains / nested ESS) is unclear; it is shown for completeness only.\n",
128
        "nested_normalized_rhat_vals = 1 / nested_normalized_ess_vals\n",
129
        "\n",
130
        "# Expected per-chain ESS is n_steps\n",
131
        "print('ESS mean + std:', ess_vals.mean(), ess_vals.std())\n",
132
        "print('pooled ESS mean + std:', pooled_ess_vals.mean(), pooled_ess_vals.std())\n",
133
        "print('nested ESS mean + std:', nested_ess_vals.mean(), nested_ess_vals.std())\n",
134
        "print('nested normalized ESS mean + std:', nested_normalized_ess_vals.mean(), nested_normalized_ess_vals.std())\n",
135
        "print()\n",
136
        "print('rhat - 1 mean + std:', rhat_vals.mean(), rhat_vals.std())\n",
137
        "print('pooled rhat - 1 mean + std:', pooled_rhat_vals.mean(), pooled_rhat_vals.std())\n",
138
        "print('nested rhat - 1 mean + std:', nested_rhat_vals.mean(), nested_rhat_vals.std())\n",
139
        "print('nested normalized rhat - 1 mean + std:', nested_normalized_rhat_vals.mean(), nested_normalized_rhat_vals.std())\n",
140
        "\n",
141
        "fig = plt.figure(figsize=(24, 6))\n",
142
        "plt.subplot(2, 4, 1)\n",
143
        "plt.title('log10 ESS')\n",
144
        "plt.hist(jnp.log10(ess_vals), histtype='step', density=True, bins=50)\n",
145
        "\n",
146
        "plt.subplot(2, 4, 2)\n",
147
        "plt.title('log10 pooled ESS')\n",
148
        "plt.hist(jnp.log10(pooled_ess_vals), histtype='step', density=True, bins=50)\n",
149
        "\n",
150
        "plt.subplot(2, 4, 3)\n",
151
        "plt.title('log10 nested ESS')\n",
152
        "plt.hist(jnp.log10(nested_ess_vals), histtype='step', density=True, bins=50);\n",
153
        "\n",
154
        "plt.subplot(2, 4, 4)\n",
155
        "plt.title('log10 nested normalized ESS')\n",
156
        "plt.hist(jnp.log10(nested_normalized_ess_vals), histtype='step', density=True, bins=50);\n",
157
        "\n",
158
        "plt.subplot(2, 4, 5)\n",
159
        "plt.title('log10 rhat - 1')\n",
160
        "plt.hist(jnp.log10(rhat_vals), histtype='step', density=True, bins=50)\n",
161
        "\n",
162
        "plt.subplot(2, 4, 6)\n",
163
        "plt.title('log10 pooled rhat - 1')\n",
164
        "plt.hist(jnp.log10(pooled_rhat_vals), histtype='step', density=True, bins=50)\n",
165
        "\n",
166
        "plt.subplot(2, 4, 7)\n",
167
        "plt.title('log10 nested rhat - 1')\n",
168
        "plt.hist(jnp.log10(nested_rhat_vals), histtype='step', density=True, bins=50);\n",
169
        "\n",
170
        "plt.subplot(2, 4, 8)\n",
171
        "plt.title('log10 nested normalized rhat - 1')\n",
172
        "plt.hist(jnp.log10(nested_normalized_rhat_vals), histtype='step', density=True, bins=50);\n",
173
        "\n",
174
        "fig.tight_layout()"
175
      ]
176
    },
177
    {
178
      "cell_type": "markdown",
179
      "metadata": {
180
        "id": "7CzwZYfZF1jM"
181
      },
182
      "source": [
183
        "# MCMC test: between-chain variance of regular chains vs. nested super chains under HMC"
184
      ]
185
    },
186
    {
187
      "cell_type": "code",
188
      "execution_count": null,
189
      "metadata": {
190
        "id": "1j78wa-TVW3K"
191
      },
192
      "outputs": [],
193
      "source": [
194
        "dist = tfd.Normal(0., 1.)\n",
195
        "n_chains = 10240\n",
196
        "n_super_chains = 8\n",
197
        "n_steps = 100\n",
198
        "n_sub_chains = n_chains // n_super_chains\n",
199
        "\n",
200
        "def target_log_prob_fn(x):\n",
201
        "  return dist.log_prob(x), ()\n",
202
        "\n",
203
        "\n",
204
        "def kernel(hmc_state, seed):\n",
205
        "  hmc_seed, seed = jax.random.split(seed)\n",
206
        "  hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(\n",
207
        "      hmc_state,\n",
208
        "      target_log_prob_fn=target_log_prob_fn,\n",
209
        "      step_size=0.5,\n",
210
        "      num_integrator_steps=1,\n",
211
        "      seed=hmc_seed)\n",
212
        "  return (hmc_state, seed), (hmc_state.state, hmc_extra.is_accepted)\n",
213
        "\n",
214
        "\n",
215
        "\n",
216
        "init_x = dist.sample([n_chains], seed=jax.random.PRNGKey(0))\n",
217
        "\n",
218
        "_, (chain, is_accepted) = fun_mc.trace((fun_mc.hamiltonian_monte_carlo_init(init_x,\n",
219
        "    target_log_prob_fn), jax.random.PRNGKey(0)), kernel, 10000)\n",
220
        "\n",
221
        "init_x2 = dist.sample([n_super_chains], seed=jax.random.PRNGKey(3))\n",
222
        "init_x2 = jnp.repeat(init_x2, n_sub_chains)\n",
223
        "#init_x2 = dist.sample([num_chains], seed=jax.random.PRNGKey(3))\n",
224
        "init_x2 = init_x2.reshape([n_super_chains, n_sub_chains])\n",
225
        "\n",
226
        "_, (chain2, is_accepted2) = fun_mc.trace((fun_mc.hamiltonian_monte_carlo_init(init_x2,\n",
227
        "    target_log_prob_fn), jax.random.PRNGKey(3)), kernel, 10000)\n",
228
        "\n",
229
        "chain = jnp.concatenate([init_x[jnp.newaxis], chain], 0)\n",
230
        "chain2 = jnp.concatenate([init_x2[jnp.newaxis], chain2], 0)"
231
      ]
232
    },
233
    {
234
      "cell_type": "code",
235
      "execution_count": null,
236
      "metadata": {
237
        "id": "8udM_46EsKaC"
238
      },
239
      "outputs": [],
240
      "source": [
241
        "plt.plot(chain[:, :4])"
242
      ]
243
    },
244
    {
245
      "cell_type": "code",
246
      "execution_count": null,
247
      "metadata": {
248
        "id": "QoMhF_n0tSQ2"
249
      },
250
      "outputs": [],
251
      "source": [
252
        "plt.plot(chain2[:, 0, :4])"
253
      ]
254
    },
255
    {
256
      "cell_type": "code",
257
      "execution_count": null,
258
      "metadata": {
259
        "id": "krTTXX9lziZV"
260
      },
261
      "outputs": [],
262
      "source": [
263
        "chain2[0].mean(-1).var(0), chain[0].var(0)"
264
      ]
265
    },
266
    {
267
      "cell_type": "code",
268
      "execution_count": null,
269
      "metadata": {
270
        "id": "I1oq9lixtXC1"
271
      },
272
      "outputs": [],
273
      "source": [
274
        "between_reg = (jnp.cumsum(chain, 0) / jnp.arange(1, chain.shape[0] + 1)[:, jnp.newaxis]).var(1)\n",
275
        "#between_reg = (jnp.cumsum(chain2, 0) / jnp.arange(1, chain.shape[0] + 1)[:, jnp.newaxis, jnp.newaxis]).var((1, 2))\n",
276
        "super_chain = chain2.mean(-1)\n",
277
        "between_nested = (jnp.cumsum(super_chain, 0) / jnp.arange(1, super_chain.shape[0] + 1)[:, jnp.newaxis]).var(1)"
278
      ]
279
    },
280
    {
281
      "cell_type": "code",
282
      "execution_count": null,
283
      "metadata": {
284
        "id": "irascTzltwiv"
285
      },
286
      "outputs": [],
287
      "source": [
288
        "plt.title('between chain variance')\n",
289
        "plt.plot(between_reg, label='regular chain')\n",
290
        "plt.plot(between_nested, label='super chain')\n",
291
        "plt.plot(between_reg / n_sub_chains, label='regular chain / n_sub_chains')\n",
292
        "plt.axhline(1e-2, ls='--', color='black', lw=2)\n",
293
        "plt.xscale('log')\n",
294
        "plt.yscale('log')\n",
295
        "plt.xlabel('chain length')\n",
296
        "plt.legend()"
297
      ]
298
    },
299
    {
300
      "cell_type": "code",
301
      "execution_count": null,
302
      "metadata": {
303
        "id": "C5LEwbtr_UHG"
304
      },
305
      "outputs": [],
306
      "source": [
307
        "between_reg2 = between_reg\n",
308
        "between_nested2 = between_nested\n",
309
        "n_sub_chains2 = n_sub_chains"
310
      ]
311
    },
312
    {
313
      "cell_type": "code",
314
      "execution_count": null,
315
      "metadata": {
316
        "id": "Y2Fp5rLNTxkZ"
317
      },
318
      "outputs": [],
319
      "source": [
320
        ""
321
      ]
322
    },
323
    {
324
      "cell_type": "code",
325
      "execution_count": null,
326
      "metadata": {
327
        "id": "cxG6z3UuSwW4"
328
      },
329
      "outputs": [],
330
      "source": [
331
        "print(n_sub_chains)\n",
332
        "print(n_sub_chains2)\n",
333
        "plt.figure(figsize=(12, 8))\n",
334
        "plt.title('between chain variance')\n",
335
        "\n",
336
        "plt.plot(between_reg, label='regular chain')\n",
337
        "plt.plot(between_nested, label='super chain')\n",
338
        "plt.plot(between_reg / n_sub_chains, label='regular chain / n_sub_chains')\n",
339
        "\n",
340
        "plt.plot(between_reg2, label='regular chain 2')\n",
341
        "plt.plot(between_nested2, label='super chain 2')\n",
342
        "plt.plot(between_reg2 / n_sub_chains2, label='regular chain 2 / n_sub_chains 2')\n",
343
        "\n",
344
        "plt.axhline(1e-2, ls='--', color='black', lw=2)\n",
345
        "plt.xscale('log')\n",
346
        "plt.yscale('log')\n",
347
        "plt.xlabel('chain length')\n",
348
        "plt.legend()"
349
      ]
350
    }
351
  ],
352
  "metadata": {
353
    "colab": {
354
      "collapsed_sections": [],
355
      "last_runtime": {
356
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
357
        "kind": "private"
358
      },
359
      "name": "Nested R-hat.ipynb",
360
      "private_outputs": true,
361
      "provenance": [
362
        {
363
          "file_id": "199pnTb5NJtFaLozNAlp3todlh4yjHD6g",
364
          "timestamp": 1632344135035
365
        }
366
      ]
367
    },
368
    "kernelspec": {
369
      "display_name": "Python 3",
370
      "name": "python3"
371
    },
372
    "language_info": {
373
      "name": "python"
374
    }
375
  },
376
  "nbformat": 4,
377
  "nbformat_minor": 0
378
}
379

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.