import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma

_KL_REGISTRY: Dict[
    Tuple[Type, Type], Callable
] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: Dict[
    Tuple[Type, Type], Callable
] = {}  # Memoized version mapping many specific (type, type) pairs to functions.

__all__ = ["register_kl", "kl_divergence"]

def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type,type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is raised. For example, to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    if not isinstance(type_p, type) or not issubclass(type_p, Distribution):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not isinstance(type_q, type) or not issubclass(type_q, Distribution):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator
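
# Editor's note (illustrative sketch, not part of the upstream module): a minimal
# usage example for `register_kl`. `ClippedNormal` is a hypothetical Normal subclass;
# because lookup walks the subclass lattice, the registration below would also be
# used for KL(ClippedNormal, ClippedNormal) unless a more specific pair is registered.
#
#     class ClippedNormal(Normal):
#         pass
#
#     @register_kl(ClippedNormal, Normal)
#     def _kl_clippednormal_normal(p, q):
#         # Reuse the closed form for the underlying Normal parameters.
#         return _kl_normal_normal(p, q)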


@total_ordering
class _Match:
    __slots__ = ["types"]

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True


def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.
    """
    matches = [
        (super_p, super_q)
        for super_p, super_q in _KL_REGISTRY
        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
    ]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    # mypy isn't smart enough to know that _Match implements __lt__
    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn(
            "Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
                type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
            ),
            RuntimeWarning,
        )
    return left_fun


def _infinite_like(tensor):
    """
    Helper function for obtaining infinite KL Divergence throughout
    """
    return torch.full_like(tensor, inf)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x
    """
    return tensor * tensor.log()

def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary leading batch dimensions
    """
    n = bmat.size(-1)
    m = bmat.size(-2)
    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
    return flat_trace.reshape(bmat.shape[:-2])
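
# Editor's note (illustrative sketch, not part of the upstream module): tr(X @ X.T)
# equals the squared Frobenius norm of X, which is what the reshape-and-sum above
# computes batch-wise. A quick check:
#
#     X = torch.randn(3, 2, 5, 4)                       # batch shape (3, 2)
#     expected = (X @ X.mT).diagonal(dim1=-2, dim2=-1).sum(-1)
#     assert torch.allclose(_batch_trace_XXT(X), expected)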


def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac{p(x)}{q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)
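
# Editor's note (illustrative sketch, not part of the upstream module): typical usage
# goes through `torch.distributions`, e.g.
#
#     p = Normal(torch.zeros(3), torch.ones(3))
#     q = Normal(torch.ones(3), 2 * torch.ones(3))
#     kl = kl_divergence(p, q)   # shape (3,), one value per batch element
#
# Dispatch is by type: here the (Normal, Normal) registration below is selected,
# memoized in _KL_MEMOIZE, and evaluated elementwise over the batch.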


################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2
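
# Editor's note (illustrative, not part of the upstream module): the softplus terms
# above are numerically stable logs, since softplus(-logits) = -log(probs) and
# softplus(logits) = -log(1 - probs). So t1 + t2 is the familiar
#
#     p * log(p / q) + (1 - p) * log((1 - p) / (1 - q))
#
# with the 0 * log(0) = 0 and log(x / 0) = inf conventions patched in explicitly.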


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    kl = p.total_count * (
        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
    )
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    t = p.probs * (p.logits - q.logits)
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    t1 = p.mean * (p.logits - q.logits)
    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return t1 + t2 + t3


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1

@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    if type(p) is not type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families cannot "
            "be computed using Bregman divergences"
        )
    p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        result -= _sum_rightmost(term, len(q.event_shape))
    return result
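
# Editor's note (illustrative sketch, not part of the upstream module): the
# Bregman-divergence route above should agree with the registered closed forms for
# any single exponential family. A quick sanity check with two Normal instances:
#
#     p = Normal(torch.tensor(0.0), torch.tensor(1.0))
#     q = Normal(torch.tensor(1.0), torch.tensor(2.0))
#     assert torch.allclose(_kl_expfamily_expfamily(p, q), _kl_normal_normal(p, q))
#
# (kl_divergence itself would dispatch to the more specific (Normal, Normal) pair.)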


@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(
        combined_batch_shape + (n, p.cov_factor.size(-1))
    )
    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
        combined_batch_shape + (n, n)
    )
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
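
# Editor's note (illustrative sketch, not part of the upstream module): for the closed
# form above, KL(N(0, 1) || N(1, 2)) = log(2) + (1 + 1) / (2 * 4) - 1 / 2, about 0.4431,
# and a Monte Carlo estimate should land close to it:
#
#     p, q = Normal(0.0, 1.0), Normal(1.0, 2.0)
#     x = p.sample((100_000,))
#     mc = (p.log_prob(x) - q.log_prob(x)).mean()   # noisy estimate of the same KL
#     # kl_divergence(p, q) is about 0.4431; mc should agree to roughly 1e-2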


@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)
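
# Editor's note (illustrative sketch, not part of the upstream module): KL is invariant
# under a shared invertible transform, so two LogNormals reduce to the KL of their base
# Normals via this rule, e.g.
#
#     from torch.distributions import LogNormal
#     p, q = LogNormal(0.0, 1.0), LogNormal(1.0, 2.0)
#     # kl_divergence(p, q) == kl_divergence(Normal(0.0, 1.0), Normal(1.0, 2.0))
#
# (LogNormal subclasses TransformedDistribution, so dispatch lands here.)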


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result


# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    return (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    return (
        -p.entropy()
        - q.rate.log()
        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
    )


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (
        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
    )
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4


# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (
        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
        + E_beta.pow(2)
    ) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result


# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean


# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    t1 = -p.entropy()
    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
        q.scale
    )
    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
        2.0 * torch.square(q.scale)
    )
    return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    return torch.where(
        torch.max(
            torch.ge(q.low, p.support.lower_bound),
            torch.le(q.high, p.support.upper_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return (
        t1
        + ratio
        + q.concentration.lgamma()
        + q.concentration * _euler_gamma
        - (1 + _euler_gamma)
    )


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3


# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()
    t3 = q.loc / p.rate
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = (
        torch.exp(loc_scale_ratio)
        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
        - loc_scale_ratio
    )
    return t1 + t2 + t3


# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = (
        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return (
        t1
        + (p.concentration - 1) * p.concentration.digamma()
        + (t2 - t3 + t4) / var_normal
    )


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    return _infinite_like(p.loc)


# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    loc_diff = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    loc_diff_scale_ratio = loc_diff / p.scale
    t1 = torch.log(scale_ratio)
    t2 = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
    )
    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result


# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    result = (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )
    return torch.where(
        torch.max(
            torch.ge(p.high, q.support.upper_bound),
            torch.le(p.low, q.support.lower_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2


# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
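
# Editor's note (illustrative sketch, not part of the upstream module): wrapping a
# distribution in Independent reinterprets batch dims as event dims, so the KL of the
# wrapped distribution is the per-dimension KL summed over those dims, e.g.
#
#     p = Independent(Normal(torch.zeros(4), torch.ones(4)), 1)
#     q = Independent(Normal(torch.ones(4), torch.ones(4)), 1)
#     # kl_divergence(p, q) is a scalar equal to
#     # kl_divergence(p.base_dist, q.base_dist).sum(-1)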


@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    # From https://arxiv.org/abs/1905.10965
    t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
    t2 = (4 * p.scale * q.scale).log()
    return t1 - t2


def _add_kl_info():
    """Appends a list of implemented KL functions to the doc for kl_divergence."""
    rows = [
        "KL divergence is currently implemented for the following distribution pairs:"
    ]
    for p, q in sorted(
        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
    ):
        rows.append(
            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
        )
    kl_info = "\n\t".join(rows)
    if kl_divergence.__doc__:
        kl_divergence.__doc__ += kl_info  # type: ignore[operator]