# NOTE(review): web-scraper residue (site header and interleaved viewer line
# numbers) removed; the module source follows.
r"""
2
The ``distributions`` package contains parameterizable probability distributions
3
and sampling functions. This allows the construction of stochastic computation
4
graphs and stochastic gradient estimators for optimization. This package
5
generally follows the design of the `TensorFlow Distributions`_ package.
6

7
.. _`TensorFlow Distributions`:
8
    https://arxiv.org/abs/1711.10604
9

10
It is not possible to directly backpropagate through random samples. However,
11
there are two main methods for creating surrogate functions that can be
12
backpropagated through. These are the score function estimator/likelihood ratio
13
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
14
seen as the basis for policy gradient methods in reinforcement learning, and the
15
pathwise derivative estimator is commonly seen in the reparameterization trick
16
in variational autoencoders. Whilst the score function only requires the value
17
of samples :math:`f(x)`, the pathwise derivative requires the derivative
18
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
19
example. For more details see
20
`Gradient Estimation Using Stochastic Computation Graphs`_ .
21

22
.. _`Gradient Estimation Using Stochastic Computation Graphs`:
23
     https://arxiv.org/abs/1506.05254
24

25
Score function
26
^^^^^^^^^^^^^^
27

28
When the probability density function is differentiable with respect to its
29
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
30
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:
31

32
.. math::
33

34
    \Delta\theta  = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}
35

36
where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
37
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
38
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
39

40
In practice we would sample an action from the output of a network, apply this
41
action in an environment, and then use ``log_prob`` to construct an equivalent
42
loss function. Note that we use a negative because optimizers use gradient
43
descent, whilst the rule above assumes gradient ascent. With a categorical
44
policy, the code for implementing REINFORCE would be as follows::
45

46
    probs = policy_network(state)
47
    # Note that this is equivalent to what used to be called multinomial
48
    m = Categorical(probs)
49
    action = m.sample()
50
    next_state, reward = env.step(action)
51
    loss = -m.log_prob(action) * reward
52
    loss.backward()
53

54
Pathwise derivative
55
^^^^^^^^^^^^^^^^^^^
56

57
The other way to implement these stochastic/policy gradients would be to use the
58
reparameterization trick from the
59
:meth:`~torch.distributions.Distribution.rsample` method, where the
60
parameterized random variable can be constructed via a parameterized
61
deterministic function of a parameter-free random variable. The reparameterized
62
sample therefore becomes differentiable. The code for implementing the pathwise
63
derivative would be as follows::
64

65
    params = policy_network(state)
66
    m = Normal(*params)
67
    # Any distribution with .has_rsample == True could work based on the application
68
    action = m.rsample()
69
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
70
    loss = -reward
71
    loss.backward()
72
"""
73

74
from .bernoulli import Bernoulli
75
from .beta import Beta
76
from .binomial import Binomial
77
from .categorical import Categorical
78
from .cauchy import Cauchy
79
from .chi2 import Chi2
80
from .constraint_registry import biject_to, transform_to
81
from .continuous_bernoulli import ContinuousBernoulli
82
from .dirichlet import Dirichlet
83
from .distribution import Distribution
84
from .exp_family import ExponentialFamily
85
from .exponential import Exponential
86
from .fishersnedecor import FisherSnedecor
87
from .gamma import Gamma
88
from .geometric import Geometric
89
from .gumbel import Gumbel
90
from .half_cauchy import HalfCauchy
91
from .half_normal import HalfNormal
92
from .independent import Independent
93
from .inverse_gamma import InverseGamma
94
from .kl import _add_kl_info, kl_divergence, register_kl
95
from .kumaraswamy import Kumaraswamy
96
from .laplace import Laplace
97
from .lkj_cholesky import LKJCholesky
98
from .log_normal import LogNormal
99
from .logistic_normal import LogisticNormal
100
from .lowrank_multivariate_normal import LowRankMultivariateNormal
101
from .mixture_same_family import MixtureSameFamily
102
from .multinomial import Multinomial
103
from .multivariate_normal import MultivariateNormal
104
from .negative_binomial import NegativeBinomial
105
from .normal import Normal
106
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
107
from .pareto import Pareto
108
from .poisson import Poisson
109
from .relaxed_bernoulli import RelaxedBernoulli
110
from .relaxed_categorical import RelaxedOneHotCategorical
111
from .studentT import StudentT
112
from .transformed_distribution import TransformedDistribution
113
from .transforms import *  # noqa: F403
114
from . import transforms
115
from .uniform import Uniform
116
from .von_mises import VonMises
117
from .weibull import Weibull
118
from .wishart import Wishart
119

120
_add_kl_info()
121
del _add_kl_info
122

123
# Explicit public API of ``torch.distributions``.
# Kept in case-sensitive alphabetical order (classes, then lowercase helpers).
# Fix: "Poisson" previously appeared after "StudentT" and
# "TransformedDistribution" after "Wishart", breaking the ordering; the set of
# exported names is unchanged.
__all__ = [
    "Bernoulli",
    "Beta",
    "Binomial",
    "Categorical",
    "Cauchy",
    "Chi2",
    "ContinuousBernoulli",
    "Dirichlet",
    "Distribution",
    "Exponential",
    "ExponentialFamily",
    "FisherSnedecor",
    "Gamma",
    "Geometric",
    "Gumbel",
    "HalfCauchy",
    "HalfNormal",
    "Independent",
    "InverseGamma",
    "Kumaraswamy",
    "LKJCholesky",
    "Laplace",
    "LogNormal",
    "LogisticNormal",
    "LowRankMultivariateNormal",
    "MixtureSameFamily",
    "Multinomial",
    "MultivariateNormal",
    "NegativeBinomial",
    "Normal",
    "OneHotCategorical",
    "OneHotCategoricalStraightThrough",
    "Pareto",
    "Poisson",
    "RelaxedBernoulli",
    "RelaxedOneHotCategorical",
    "StudentT",
    "TransformedDistribution",
    "Uniform",
    "VonMises",
    "Weibull",
    "Wishart",
    "biject_to",
    "kl_divergence",
    "register_kl",
    "transform_to",
]
# The transform classes re-exported via ``from .transforms import *`` are part
# of the public API as well; pull in whatever that submodule declares public.
__all__.extend(transforms.__all__)