pytorch
102 строки · 3.2 Кб
1# mypy: allow-untyped-defs
2from numbers import Number3
4import torch5from torch import nan6from torch.distributions import constraints7from torch.distributions.distribution import Distribution8from torch.distributions.utils import broadcast_all9from torch.types import _size10

# Only the Uniform distribution is part of this module's public API.
__all__ = [
    "Uniform",
]


class Uniform(Distribution):
    r"""
    Uniform distribution over the half-open interval ``[low, high)``.

    ``low`` and ``high`` are broadcast against each other. When both are plain
    Python numbers the batch shape is empty; otherwise it is the broadcast
    shape of the two parameters.

    Example::

        >>> # xdoctest: +SKIP
        >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
        >>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
        tensor([ 2.3418])

    Args:
        low (float or Tensor): lower end of the range (inclusive).
        high (float or Tensor): upper end of the range (exclusive).
        validate_args (bool, optional): forwarded to ``Distribution``. When
            validation is active, construction raises ``ValueError`` unless
            ``low < high`` holds elementwise, and ``log_prob`` / ``cdf``
            validate their input before evaluating.
    """

    # TODO allow (loc,scale) parameterization to allow independent constraints.
    # Each bound's valid range depends on the other, hence ``dependent``
    # constraints; the ordering check itself is performed in ``__init__``.
    arg_constraints = {
        "low": constraints.dependent(is_discrete=False, event_dim=0),
        "high": constraints.dependent(is_discrete=False, event_dim=0),
    }
    has_rsample = True

    def __init__(self, low, high, validate_args=None):
        self.low, self.high = broadcast_all(low, high)

        # Two plain numbers describe a single distribution (empty batch shape);
        # otherwise the broadcast parameter shape is the batch shape.
        scalar_params = isinstance(low, Number) and isinstance(high, Number)
        batch_shape = torch.Size() if scalar_params else self.low.size()
        super().__init__(batch_shape, validate_args=validate_args)

        if self._validate_args:
            # Reject empty or inverted intervals anywhere in the batch.
            if not self.low.lt(self.high).all():
                raise ValueError("Uniform is not defined when low>= high")

    @property
    def mean(self):
        # Midpoint of the interval.
        return (self.high + self.low) / 2

    @property
    def mode(self):
        # Every point of the interval is equally likely, so there is no unique
        # mode; NaN is reported with the shape of ``high``.
        return nan * self.high

    @property
    def stddev(self):
        width = self.high - self.low
        return width / 12**0.5

    @property
    def variance(self):
        width = self.high - self.low
        return width.pow(2) / 12

    def expand(self, batch_shape, _instance=None):
        expanded = self._get_checked_instance(Uniform, _instance)
        target_shape = torch.Size(batch_shape)
        expanded.low = self.low.expand(target_shape)
        expanded.high = self.high.expand(target_shape)
        # The expanded parameters come from ``self``, so construction skips
        # validation; afterwards the instance inherits ``self``'s setting.
        super(Uniform, expanded).__init__(target_shape, validate_args=False)
        expanded._validate_args = self._validate_args
        return expanded

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        # The support moves with the parameters, hence a dependent property.
        return constraints.interval(self.low, self.high)

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        out_shape = self._extended_shape(sample_shape)
        unit = torch.rand(out_shape, dtype=self.low.dtype, device=self.low.device)
        # Affine map of the [0, 1) draws onto [low, high); this keeps the
        # sample differentiable with respect to ``low`` and ``high``.
        return self.low + unit * (self.high - self.low)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # 0/1 indicators of ``low <= value`` and ``value < high``, in the
        # parameters' dtype.
        at_or_above_low = self.low.le(value).type_as(self.low)
        below_high = self.high.gt(value).type_as(self.low)
        # log(1) = 0 inside the support, log(0) = -inf outside it.
        inside = at_or_above_low * below_high
        return torch.log(inside) - torch.log(self.high - self.low)

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        fraction = (value - self.low) / (self.high - self.low)
        # Saturate to 0 below ``low`` and to 1 from ``high`` upwards.
        return fraction.clamp(min=0, max=1)

    def icdf(self, value):
        # Linear inverse of the CDF; ``value`` is not range-checked here.
        width = self.high - self.low
        return value * width + self.low

    def entropy(self):
        # Differential entropy of a uniform distribution: log(high - low).
        return (self.high - self.low).log()