pytorch

Форк
0
/
mixture_same_family.py 
214 строк · 8.3 Кб
1
from typing import Dict
2

3
import torch
4
from torch.distributions import Categorical, constraints
5
from torch.distributions.distribution import Distribution
6

7
# Public names exported by a star-import of this module.
__all__ = [
    "MixtureSameFamily",
]
8

9

10
class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: a `torch.distributions.Categorical` instance
            (anything else raises `ValueError`). Manages the probability of
            selecting each component. The number of categories must match
            the rightmost batch dimension of the `component_distribution`.
            Should have either scalar `batch_shape` or `batch_shape`
            matching `component_distribution.batch_shape[:-1]`; the
            constructor only rejects trailing dimensions that differ while
            both are != 1.
        component_distribution: a `torch.distributions.Distribution`
            instance. Its right-most batch dimension indexes the components.
        validate_args: forwarded to `Distribution.__init__`.
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                " The Mixture distribution needs to be an "
                " instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution need to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that the batch shapes are compatible. Dimensions are compared
        # right-aligned (as in broadcasting); a pair is rejected only when the
        # sizes differ and neither of them is 1.
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape`({cdbs})"
                )

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution component` ({km}) does not"
                " equal `component_distribution.batch_shape[-1]`"
                f" ({kc})"
            )
        self._num_component = km

        # The mixture's batch shape is the component batch shape without the
        # trailing component axis; its event shape is the component's.
        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose batch shape is expanded to `batch_shape`.

        Both the mixture and the component distributions are expanded; the
        component keeps its trailing `k` axis.
        """
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        """The `Categorical` distribution over components."""
        return self._mixture_distribution

    @property
    def component_distribution(self):
        """The batched component distribution (rightmost batch dim is `k`)."""
        return self._component_distribution

    @property
    def mean(self):
        """Mixture mean: sum over components of weight * component mean."""
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        """Mixture variance via the law of total variance."""
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        """Mixture CDF: probability-weighted sum of the component CDFs at `x`."""
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        # NOTE(review): mix_prob is not padded for event dims, so this assumes
        # the component cdf returns one value per component — confirm.
        mix_prob = self.mixture_distribution.probs

        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        """Log density: logsumexp over components of log weight + log prob."""
        if self._validate_args:
            self._validate_sample(x)
        # Insert the component axis so x broadcasts against every component.
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape `sample_shape + batch_shape + event_shape`.

        A component index is drawn per sample and batch member, then the
        matching draw from the component distribution is selected. Gradients
        are not tracked (`has_rsample` is False).
        """
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            # The component axis k sits right after the sample and batch dims.
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # The constructor accepts a mixture batch shape that is scalar or
            # only broadcast-compatible with self.batch_shape. Expand it so
            # mix_sample has exactly shape [n, B] (same rank as the gather
            # input) and each batch member draws its component independently.
            mixture = self.mixture_distribution
            if mixture.batch_shape != self.batch_shape:
                mixture = mixture.expand(self.batch_shape)

            # mixture samples [n, B]
            mix_sample = mixture.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension with an index of shape [n, B, 1, E]
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        # Insert a singleton component axis just before the event dims.
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        # Reshape a [mixture_batch, k] tensor (the mixture probs) so that it
        # broadcasts against component statistics of shape [B, k, E]:
        # missing batch dims are prepended on the LEFT (torch broadcasting is
        # right-aligned) and one singleton dim is appended per event dim.
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = max(dist_batch_ndims - cat_batch_ndims, 0)
        return x.reshape(
            torch.Size(pad_ndims * [1])
            + x.shape
            + torch.Size(self._event_ndims * [1])
        )

    def __repr__(self):
        args_string = (
            f"\n  {self.mixture_distribution},\n  {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"
215

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.