1r"""
2The ``distributions`` package contains parameterizable probability distributions
3and sampling functions. This allows the construction of stochastic computation
4graphs and stochastic gradient estimators for optimization. This package
5generally follows the design of the `TensorFlow Distributions`_ package.
6
7.. _`TensorFlow Distributions`:
8https://arxiv.org/abs/1711.10604
9
10It is not possible to directly backpropagate through random samples. However,
11there are two main methods for creating surrogate functions that can be
12backpropagated through. These are the score function estimator/likelihood ratio
13estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
14seen as the basis for policy gradient methods in reinforcement learning, and the
15pathwise derivative estimator is commonly seen in the reparameterization trick
16in variational autoencoders. Whilst the score function only requires the value
17of samples :math:`f(x)`, the pathwise derivative requires the derivative
18:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
19example. For more details see
20`Gradient Estimation Using Stochastic Computation Graphs`_ .
21
22.. _`Gradient Estimation Using Stochastic Computation Graphs`:
23https://arxiv.org/abs/1506.05254
24
Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()
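
To make this concrete, the following is a minimal, self-contained sketch of one
full REINFORCE update step. The two-layer policy, the optimizer choice, and the
placeholder state/reward tensors are assumptions made for the example and stand
in for a real network and environment::

    import torch
    from torch import nn, optim
    from torch.distributions import Categorical

    # Hypothetical policy: maps a 4-dim state to probabilities over 2 actions.
    policy_network = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=-1))
    optimizer = optim.Adam(policy_network.parameters(), lr=1e-2)

    state = torch.randn(4)      # stand-in for an environment observation
    probs = policy_network(state)
    m = Categorical(probs)
    action = m.sample()
    reward = torch.tensor(1.0)  # stand-in for the reward from env.step(action)

    optimizer.zero_grad()
    # Surrogate loss: its gradient is the score function estimator above.
    loss = -m.log_prob(action) * reward
    loss.backward()
    optimizer.step()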

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
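
Because :meth:`~torch.distributions.Distribution.rsample` routes gradients
through the reparameterized sample, the pathwise estimator can be verified on a
toy problem. The sketch below is an illustration, not part of the package: it
estimates :math:`\frac{d}{d\mu}\mathbb{E}[x^2]` for :math:`x \sim N(\mu, 1)`,
whose exact value is :math:`2\mu`::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor(0.5, requires_grad=True)
    m = Normal(loc, torch.tensor(1.0))
    assert m.has_rsample

    # Monte Carlo estimate of d/d(loc) E[x^2]; the exact value is 2 * loc = 1.0.
    x = m.rsample((10000,))
    (x ** 2).mean().backward()
    print(loc.grad)  # approximately 1.0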
72"""
73
from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .inverse_gamma import InverseGamma
from .kl import _add_kl_info, kl_divergence, register_kl
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *  # noqa: F403
from . import transforms
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart

_add_kl_info()
del _add_kl_info

__all__ = [
    "Bernoulli",
    "Beta",
    "Binomial",
    "Categorical",
    "Cauchy",
    "Chi2",
    "ContinuousBernoulli",
    "Dirichlet",
    "Distribution",
    "Exponential",
    "ExponentialFamily",
    "FisherSnedecor",
    "Gamma",
    "Geometric",
    "Gumbel",
    "HalfCauchy",
    "HalfNormal",
    "Independent",
    "InverseGamma",
    "Kumaraswamy",
    "LKJCholesky",
    "Laplace",
    "LogNormal",
    "LogisticNormal",
    "LowRankMultivariateNormal",
    "MixtureSameFamily",
    "Multinomial",
    "MultivariateNormal",
    "NegativeBinomial",
    "Normal",
    "OneHotCategorical",
    "OneHotCategoricalStraightThrough",
    "Pareto",
    "RelaxedBernoulli",
    "RelaxedOneHotCategorical",
    "StudentT",
    "Poisson",
    "Uniform",
    "VonMises",
    "Weibull",
    "Wishart",
    "TransformedDistribution",
    "biject_to",
    "kl_divergence",
    "register_kl",
    "transform_to",
]
__all__.extend(transforms.__all__)