# pytorch
# 214 lines · 8.3 KB
from typing import Dict

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution

__all__ = ["MixtureSameFamily"]
9
class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`. Its batch dimensions
            are aligned with the trailing dimensions of
            `component_distribution.batch_shape[:-1]`, and size-1 dimensions
            broadcast.
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes components.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                " The Mixture distribution needs to be an "
                " instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution need to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that the batch shapes are compatible. Only the trailing
        # dimensions are compared (right-aligned, as in broadcasting), and a
        # size of 1 on either side is accepted.
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape`({cdbs})"
                )

        # Check that the number of mixture components matches the component
        # distribution's rightmost batch dimension.
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution component` ({km}) does not"
                " equal `component_distribution.batch_shape[-1]`"
                f" ({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        # The mixture's batch shape is the component batch shape without `k`.
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose batch shape is expanded to `batch_shape`.

        The component distribution is expanded to `batch_shape + (k,)` and the
        mixture distribution to `batch_shape`.
        """
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        """Support of the mixture, taken from the component distribution."""
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        """The `Categorical` distribution that selects components."""
        return self._mixture_distribution

    @property
    def component_distribution(self):
        """The batched component distribution (rightmost batch dim is `k`)."""
        return self._component_distribution

    @property
    def mean(self):
        """Mixture mean: probability-weighted sum of component means, [B, E]."""
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        """Mixture variance via the law of total variance, [B, E]."""
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        """Mixture CDF: probability-weighted sum of the component CDFs at `x`."""
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs
        # NOTE(review): mix_prob is not padded for event dims, so this relies
        # on cdf_x ending in the component dim (scalar events) — confirm for
        # event_ndims > 0.
        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        """Log density: logsumexp over components of log p_k(x) + log w_k."""
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape `sample_shape + batch_shape + event_shape`.

        Samples are drawn without gradient tracking. A component index is
        sampled from the mixture, every component is sampled, and the selected
        component's sample is gathered for each batch member.
        """
        with torch.no_grad():
            sample_shape = torch.Size(sample_shape)
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # mixture samples [n, B_mix]; B_mix may be shorter than B or
            # contain size-1 dims
            mix_sample = self.mixture_distribution.sample(sample_shape)
            mix_batch_shape = self.mixture_distribution.batch_shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Build a gather index of shape [n, B, 1, E]. The mixture batch
            # dims are right-aligned with B (the same alignment log_prob's
            # broadcasting uses) and then broadcast to the full batch shape,
            # so every batch member selects its own component and the index
            # has as many dims as comp_samples (required by torch.gather).
            mix_sample_r = mix_sample.reshape(
                sample_shape
                + torch.Size([1] * (batch_len - len(mix_batch_shape)))
                + mix_batch_shape
                + torch.Size([1] * (len(es) + 1))
            )
            index = mix_sample_r.expand(
                sample_shape + self.batch_shape + torch.Size([1]) + es
            )

            # Gather along the k dimension, then drop it.
            samples = torch.gather(comp_samples, gather_dim, index)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        """Insert a singleton component dim just before the event dims."""
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        """Append singleton event dims to a mixture tensor shaped [B_mix, k].

        Leading batch dims are left to right-aligned broadcasting against the
        component statistics ([B, k, E]). This matches log_prob and the
        constructor's batch-shape check. Inserting singleton dims between the
        mixture batch dims and `k` would misalign multi-dimensional mixture
        batches.
        """
        return x.reshape(x.shape + torch.Size(self._event_ndims * [1]))

    def __repr__(self):
        args_string = (
            f"\n  {self.mixture_distribution},\n  {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"