pytorch
83 строки · 2.7 Кб
# mypy: allow-untyped-defs
import math
from numbers import Number
from typing import Optional, Union

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all, euler_constant
12
# Public API of this module: only the Gumbel distribution class is exported.
__all__ = ["Gumbel"]
15
class Gumbel(TransformedDistribution):
    r"""
    Samples from a Gumbel Distribution.

    With location :math:`\mu` (``loc``) and scale :math:`\beta` (``scale``),
    the density is

    .. math::
        f(x) = \frac{1}{\beta} \exp\left(-\left(z + e^{-z}\right)\right),
        \qquad z = \frac{x - \mu}{\beta}.

    It is built as a :class:`TransformedDistribution`: a sample :math:`u`
    from a uniform base distribution is mapped to
    :math:`\mu - \beta \log(-\log u)`, the inverse of the Gumbel CDF.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
        >>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
        tensor([ 1.0124])

    Args:
        loc (float or Tensor): Location parameter of the distribution
        scale (float or Tensor): Scale parameter of the distribution
        validate_args (bool, optional): Forwarded both to the uniform base
            distribution and to :class:`TransformedDistribution`.
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real

    def __init__(
        self,
        loc: Union[Tensor, float],
        scale: Union[Tensor, float],
        validate_args: Optional[bool] = None,
    ) -> None:
        # Broadcast loc/scale to a common shape; Python numbers become tensors.
        self.loc, self.scale = broadcast_all(loc, scale)
        finfo = torch.finfo(self.loc.dtype)
        # The base samples u are kept between finfo.tiny and 1 - finfo.eps
        # rather than [0, 1): u > 0 keeps log(u) finite, and u < 1 keeps
        # -log(u) > 0 so that the second log below is finite as well.
        if isinstance(loc, Number) and isinstance(scale, Number):
            # Both parameters are plain numbers: a scalar base distribution.
            base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
        else:
            # Tensor parameters: give the base distribution loc's batch shape.
            base_dist = Uniform(
                torch.full_like(self.loc, finfo.tiny),
                torch.full_like(self.loc, 1 - finfo.eps),
                validate_args=validate_args,
            )
        # u -> log(u) -> -log(u) -> log(-log(u)) -> loc - scale * log(-log(u)),
        # which is the Gumbel inverse CDF applied to u.
        transforms = [
            ExpTransform().inv,
            AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
            ExpTransform().inv,
            AffineTransform(loc=loc, scale=-self.scale),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution expanded to ``batch_shape``."""
        new = self._get_checked_instance(Gumbel, _instance)
        # Expand the Gumbel-specific parameters here; the remaining state is
        # handled by TransformedDistribution.expand.
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    # Explicitly defining the log probability function for Gumbel due to
    # precision issues: this closed form replaces the inherited
    # TransformedDistribution.log_prob.
    def log_prob(self, value: Tensor) -> Tensor:
        """Return ``y - exp(y) - log(scale)`` with ``y = (loc - value) / scale``."""
        if self._validate_args:
            self._validate_sample(value)
        y = (self.loc - value) / self.scale
        return (y - y.exp()) - self.scale.log()

    @property
    def mean(self) -> Tensor:
        """Mean ``loc + scale * euler_constant`` (Euler-Mascheroni constant)."""
        return self.loc + self.scale * euler_constant

    @property
    def mode(self) -> Tensor:
        """Mode, which equals ``loc``."""
        return self.loc

    @property
    def stddev(self) -> Tensor:
        """Standard deviation ``pi / sqrt(6) * scale``."""
        return (math.pi / math.sqrt(6)) * self.scale

    @property
    def variance(self) -> Tensor:
        """Variance ``pi**2 / 6 * scale**2``, computed as ``stddev**2``."""
        return self.stddev.pow(2)

    def entropy(self) -> Tensor:
        """Entropy ``log(scale) + 1 + euler_constant``."""
        return self.scale.log() + (1 + euler_constant)