pytorch

Форк
0
/
independent.py 
125 строк · 4.5 Кб
1
from typing import Dict
2

3
import torch
4
from torch.distributions import constraints
5
from torch.distributions.distribution import Distribution
6
from torch.distributions.utils import _sum_rightmost
7

8
__all__ = ["Independent"]
9

10

11
class Independent(Distribution):
12
    r"""
13
    Reinterprets some of the batch dims of a distribution as event dims.
14

15
    This is mainly useful for changing the shape of the result of
16
    :meth:`log_prob`. For example to create a diagonal Normal distribution with
17
    the same shape as a Multivariate Normal distribution (so they are
18
    interchangeable), you can::
19

20
        >>> from torch.distributions.multivariate_normal import MultivariateNormal
21
        >>> from torch.distributions.normal import Normal
22
        >>> loc = torch.zeros(3)
23
        >>> scale = torch.ones(3)
24
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
25
        >>> [mvn.batch_shape, mvn.event_shape]
26
        [torch.Size([]), torch.Size([3])]
27
        >>> normal = Normal(loc, scale)
28
        >>> [normal.batch_shape, normal.event_shape]
29
        [torch.Size([3]), torch.Size([])]
30
        >>> diagn = Independent(normal, 1)
31
        >>> [diagn.batch_shape, diagn.event_shape]
32
        [torch.Size([]), torch.Size([3])]
33

34
    Args:
35
        base_distribution (torch.distributions.distribution.Distribution): a
36
            base distribution
37
        reinterpreted_batch_ndims (int): the number of batch dims to
38
            reinterpret as event dims
39
    """
40
    arg_constraints: Dict[str, constraints.Constraint] = {}
41

42
    def __init__(
43
        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
44
    ):
45
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
46
            raise ValueError(
47
                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
48
                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
49
            )
50
        shape = base_distribution.batch_shape + base_distribution.event_shape
51
        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
52
        batch_shape = shape[: len(shape) - event_dim]
53
        event_shape = shape[len(shape) - event_dim :]
54
        self.base_dist = base_distribution
55
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
56
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
57

58
    def expand(self, batch_shape, _instance=None):
59
        new = self._get_checked_instance(Independent, _instance)
60
        batch_shape = torch.Size(batch_shape)
61
        new.base_dist = self.base_dist.expand(
62
            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
63
        )
64
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
65
        super(Independent, new).__init__(
66
            batch_shape, self.event_shape, validate_args=False
67
        )
68
        new._validate_args = self._validate_args
69
        return new
70

71
    @property
72
    def has_rsample(self):
73
        return self.base_dist.has_rsample
74

75
    @property
76
    def has_enumerate_support(self):
77
        if self.reinterpreted_batch_ndims > 0:
78
            return False
79
        return self.base_dist.has_enumerate_support
80

81
    @constraints.dependent_property
82
    def support(self):
83
        result = self.base_dist.support
84
        if self.reinterpreted_batch_ndims:
85
            result = constraints.independent(result, self.reinterpreted_batch_ndims)
86
        return result
87

88
    @property
89
    def mean(self):
90
        return self.base_dist.mean
91

92
    @property
93
    def mode(self):
94
        return self.base_dist.mode
95

96
    @property
97
    def variance(self):
98
        return self.base_dist.variance
99

100
    def sample(self, sample_shape=torch.Size()):
101
        return self.base_dist.sample(sample_shape)
102

103
    def rsample(self, sample_shape=torch.Size()):
104
        return self.base_dist.rsample(sample_shape)
105

106
    def log_prob(self, value):
107
        log_prob = self.base_dist.log_prob(value)
108
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
109

110
    def entropy(self):
111
        entropy = self.base_dist.entropy()
112
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
113

114
    def enumerate_support(self, expand=True):
115
        if self.reinterpreted_batch_ndims > 0:
116
            raise NotImplementedError(
117
                "Enumeration over cartesian product is not implemented"
118
            )
119
        return self.base_dist.enumerate_support(expand=expand)
120

121
    def __repr__(self):
122
        return (
123
            self.__class__.__name__
124
            + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
125
        )
126

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.