import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma

_KL_REGISTRY: Dict[
    Tuple[Type, Type], Callable
] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: Dict[
    Tuple[Type, Type], Callable
] = {}  # Memoized version mapping many specific (type, type) pairs to functions.

__all__ = ["register_kl", "kl_divergence"]


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type, type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is raised. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    # Reject both non-types and types that are not Distribution subclasses.
    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator


@total_ordering
class _Match:
    __slots__ = ["types"]

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True


def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.
    """
    matches = [
        (super_p, super_q)
        for super_p, super_q in _KL_REGISTRY
        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
    ]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    # mypy isn't smart enough to know that _Match implements __lt__
    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn(
            "Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
                type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
            ),
            RuntimeWarning,
        )
    return left_fun


def _infinite_like(tensor):
    """
    Helper function for obtaining infinite KL Divergence throughout
    """
    return torch.full_like(tensor, inf)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x
    """
    return tensor * tensor.log()


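# The helper below exploits the identity tr(X @ X.T) = sum_ij X_ij**2, so the
# trace can be computed as a sum of squares without materializing X @ X.T.
# Illustrative sanity check (example only, not part of this module):
#
#     X = torch.randn(4, 5, 3)  # batch of 4 matrices, each 5 x 3
#     fast = X.reshape(-1, 5 * 3).pow(2).sum(-1).reshape(X.shape[:-2])
#     slow = torch.diagonal(X @ X.mT, dim1=-2, dim2=-1).sum(-1)
#     assert torch.allclose(fast, slow)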
def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary leading batch dimensions
    """
    n = bmat.size(-1)
    m = bmat.size(-2)
    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
    return flat_trace.reshape(bmat.shape[:-2])


def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)


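# Minimal usage sketch (illustrative only): for a registered pair such as
# (Normal, Normal), kl_divergence returns one KL value per batch element.
#
#     p = Normal(torch.zeros(3), torch.ones(3))
#     q = Normal(torch.ones(3), 2 * torch.ones(3))
#     kl = kl_divergence(p, q)  # tensor of shape (3,), elementwise KL(p || q)
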
################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


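# For Bernoulli, KL(p || q) = P * log(P / Q) + (1 - P) * log((1 - P) / (1 - Q))
# with P = p.probs and Q = q.probs. The softplus form below is an equivalent
# but numerically stabler rewrite, using softplus(-logits) = -log(probs) and
# softplus(logits) = -log(1 - probs); the indexed assignments then patch the
# 0 * log(0) = 0 and division-by-zero edge cases explicitly.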
@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    kl = p.total_count * (
        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
    )
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    t = p.probs * (p.logits - q.logits)
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    t1 = p.mean * (p.logits - q.logits)
    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return t1 + t2 + t3


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1


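# For two members of the same exponential family with natural parameters
# eta_p and eta_q and log-normalizer A, the KL divergence is the Bregman
# divergence of A:
#
#     KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>
#
# The implementation below obtains grad A(eta_p) via autograd, which is why it
# only needs each family's _natural_params and _log_normalizer.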
@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    if type(p) is not type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families cannot "
            "be computed using Bregman divergences"
        )
    p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        result -= _sum_rightmost(term, len(q.event_shape))
    return result


@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(
        combined_batch_shape + (n, p.cov_factor.size(-1))
    )
    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
        combined_batch_shape + (n, n)
    )
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


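# Closed form used below (see the Wikipedia reference in the function): with
# p = N(mu_p, Sigma_p) and q = N(mu_q, Sigma_q) over R^n,
#
#     KL(p || q) = 0.5 * (log(det(Sigma_q) / det(Sigma_p)) - n
#                         + tr(inv(Sigma_q) @ Sigma_p)
#                         + (mu_q - mu_p).T @ inv(Sigma_q) @ (mu_q - mu_p))
#
# half_term1, term2 and term3 compute these three pieces from the Cholesky
# factors, so no determinant or inverse is ever formed explicitly.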
@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Multivariate Normals with "
            "different event shapes cannot be computed"
        )

    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)


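# Scalar special case of the multivariate formula above: for p = N(mu_p,
# sigma_p^2) and q = N(mu_q, sigma_q^2),
#
#     KL(p || q) = 0.5 * ((sigma_p / sigma_q)^2 + ((mu_p - mu_q) / sigma_q)^2
#                         - 1 - log((sigma_p / sigma_q)^2))
#
# which is exactly 0.5 * (var_ratio + t1 - 1 - var_ratio.log()) below.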
@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


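# KL divergence is invariant under a shared invertible transform: applying the
# same bijection to both distributions leaves the log density ratio unchanged,
# since the Jacobian terms cancel. Illustrative sketch (example only; assumes
# ExpTransform from torch.distributions.transforms):
#
#     t = ExpTransform()
#     base_p, base_q = Normal(0.0, 1.0), Normal(1.0, 2.0)
#     kl_divergence(
#         TransformedDistribution(base_p, [t]),
#         TransformedDistribution(base_q, [t]),
#     )  # equals kl_divergence(base_p, base_q)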
@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result


# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    return (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    return (
        -p.entropy()
        - q.rate.log()
        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
    )


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (
        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
    )
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4


# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (
        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
        + E_beta.pow(2)
    ) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result


# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean


# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    t1 = -p.entropy()
    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
        q.scale
    )
    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
        2.0 * torch.square(q.scale)
    )
    return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    # The KL is infinite unless the support of p, namely [0, 1], is contained
    # in [q.low, q.high]; the comparisons must therefore be strict.
    return torch.where(
        torch.max(
            torch.gt(q.low, p.support.lower_bound),
            torch.lt(q.high, p.support.upper_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return (
        t1
        + ratio
        + q.concentration.lgamma()
        + q.concentration * _euler_gamma
        - (1 + _euler_gamma)
    )


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3


# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()
    t3 = q.loc / p.rate
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = (
        torch.exp(loc_scale_ratio)
        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
        - loc_scale_ratio
    )
    return t1 + t2 + t3


# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = (
        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return (
        t1
        + (p.concentration - 1) * p.concentration.digamma()
        + (t2 - t3 + t4) / var_normal
    )


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    return _infinite_like(p.loc)


# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    loc_diff = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    loc_diff_scale_ratio = loc_diff / p.scale
    t1 = torch.log(scale_ratio)
    t2 = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
    )
    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result


# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    result = (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )
    # The KL is infinite unless the support of p, namely [p.low, p.high], is
    # contained in the support of q, namely [0, 1]; the comparisons must
    # therefore be strict.
    return torch.where(
        torch.max(
            torch.gt(p.high, q.support.upper_bound),
            torch.lt(p.low, q.support.lower_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2


# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result


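# For Independent wrappers with matching reinterpreted_batch_ndims, the KL of
# the base distributions is computed elementwise and the reinterpreted event
# dims are summed out. Illustrative shapes (example only):
#
#     base_p = Normal(torch.zeros(4, 3), torch.ones(4, 3))  # batch (4, 3)
#     base_q = Normal(torch.ones(4, 3), torch.ones(4, 3))
#     ind_p, ind_q = Independent(base_p, 1), Independent(base_q, 1)
#     kl_divergence(ind_p, ind_q).shape  # torch.Size([4]); event dim summed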
@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)


@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    # From https://arxiv.org/abs/1905.10965
    t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
    t2 = (4 * p.scale * q.scale).log()
    return t1 - t2


def _add_kl_info():
    """Appends a list of implemented KL functions to the doc for kl_divergence."""
    rows = [
        "KL divergence is currently implemented for the following distribution pairs:"
    ]
    for p, q in sorted(
        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
    ):
        rows.append(
            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
        )
    kl_info = "\n\t".join(rows)
    if kl_divergence.__doc__:
        kl_divergence.__doc__ += kl_info  # type: ignore[operator]