Sparse soft top-k

An implementation of the following paper:

Fast, Differentiable and Sparse Top-k: a Convex Analysis Perspective. Michael E. Sander, Joan Puigcerver, Josip Djolonga, Gabriel Peyré, Mathieu Blondel. arXiv:2302.01425

Figure: Illustration of our differentiable and sparse top-k mask. For $$k=2$$, we consider $$x(s) = (3, 1, -1 + s, s)$$ and plot $$\mathrm{topkmask}(x(s))_2 + \mathrm{topkmask}(x(s))_3$$ as a function of $$s$$. We compare the hard version (no regularization) with our proposed operator using $$p$$-norm regularization: $$p=2$$ leads to an almost-everywhere differentiable operator with kinks, while $$p=4/3$$ leads to a smooth operator. Both operators are sparse: they are exactly 0 for some values of $$s$$.
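A curve like the one in the figure can be reproduced along the following lines (a minimal sketch; matplotlib and the regularization strength l=1e-1 are illustrative choices, not requirements of the package):

>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> import sparse_soft_topk
>>> ss = jnp.linspace(-3., 5., 200)
>>> def mask_tail(s, mask_fn, **kwargs):
...   # Sum of the mask entries for coordinates 2 and 3 (0-indexed) of x(s).
...   x = jnp.array([3., 1., -1. + s, s])
...   m = mask_fn(x, k=2, **kwargs)
...   return float(m[2] + m[3])
>>> plt.plot(ss, [mask_tail(s, sparse_soft_topk.sparse_soft_topk_mask_pav, l=1e-1) for s in ss], label="soft (PAV)")
>>> plt.plot(ss, [mask_tail(s, sparse_soft_topk.hard_topk_mask) for s in ss], label="hard")
>>> plt.xlabel("s")
>>> plt.legend()
>>> plt.show()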

Dependencies

  • JAX
  • Numba

Overview

We propose two fast implementations of sparse yet differentiable ("soft") top-k operators.

  • The top-k mask operator returns a bit-vector encoding whether each value is within the top-k values or not:

$$ \mathrm{topkmask}(x)_i = 1 \text{ if } x_i \text{ is among the } k \text{ biggest values of } x, \text{ and } 0 \text{ otherwise}. $$

  • The top-k operator in magnitude returns the values themselves if they are among the top-k absolute values, and 0 otherwise:

$$ \mathrm{topkmag}(x) = x \circ \mathrm{topkmask}(|x|). $$
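For instance, with $$x = (-5, -2, 3, 1)$$ and $$k=2$$, the two largest values are $$3$$ and $$1$$, while the two largest magnitudes are $$5$$ and $$3$$, so that

$$ \mathrm{topkmask}(x) = (0, 0, 1, 1) \quad \text{and} \quad \mathrm{topkmag}(x) = x \circ \mathrm{topkmask}(|x|) = (-5, 0, 3, 0). $$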

Our relaxed operators are obtained by leveraging $$p$$-norm regularization, which trades sparsity for smoothness. We provide two implementations of these smooth operators:

  • One implementation is based on the Pool Adjacent Violators (PAV) algorithm. PAV returns an exact solution and supports arbitrary regularization (with linear complexity in the case $$p=2$$ or $$p=4/3$$). It relies on a Numba implementation.

  • One implementation is based on Dykstra's projection algorithm. It returns an approximate solution. Dykstra's projection algorithm can be faster than PAV on GPU/TPU, and is implemented in pure JAX.

Example

>>> import functools
>>> import jax
>>> import jax.numpy as jnp
>>> import sparse_soft_topk
# Define input values
>>> values = jnp.array([-5., -2., 3., 1.])
# First, we can return the top-2 mask of the vector values
>>> print("soft top-k mask with PAV: ", sparse_soft_topk.sparse_soft_topk_mask_pav(values, k=2, l=1e-2))
soft top-k mask with PAV: [0. 0. 0.99999714 0.99999714]

Note that the output is sparse: it is exactly 0 outside the top-2 values.

# Similarly, we can also consider the top-2 in magnitude
>>> print("soft top-k in magnitude with PAV: ", sparse_soft_topk.sparse_soft_topk_mag_pav(values, k=2, l=1e-2))
soft top-k in magnitude with PAV: [-5.000036 -0. 2.9999416 0. ]

Alternatively, one can use Dykstra's projection algorithm to compute an approximation of the same operators:

>>> print("soft top-k mask with Dykstra: ", sparse_soft_topk.sparse_soft_topk_mask_dykstra(values, k=2, l=1e-2))
soft top-k mask with Dykstra: [0. 0. 0.99999905 0.99999905]
>>> print("soft top-k in magnitude with Dykstra: ", sparse_soft_topk.sparse_soft_topk_mag_dykstra(values, k=2, l=1e-2))
soft top-k in magnitude with Dykstra: [-4.9999804 -0. 2.999993 0. ]

One can compare these outputs with their hard counterparts:

>>> print("hard top-k mask: ", sparse_soft_topk.hard_topk_mask(values, k=2))
hard top-k mask: [0. 0. 1. 1.]
>>> print("hard top-k in magnitude: ", sparse_soft_topk.hard_topk_mag(values, k=2))
hard top-k in magnitude: [-5. -0. 3. 0.]
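For this small regularization strength l, both relaxations stay close to their hard counterparts. A quick sanity check (an illustrative addition, using only the functions shown above):

>>> hard = sparse_soft_topk.hard_topk_mask(values, k=2)
>>> pav = sparse_soft_topk.sparse_soft_topk_mask_pav(values, k=2, l=1e-2)
>>> dyk = sparse_soft_topk.sparse_soft_topk_mask_dykstra(values, k=2, l=1e-2)
>>> print(jnp.max(jnp.abs(pav - hard)), jnp.max(jnp.abs(dyk - hard)))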

Batched data

Our PAV-based implementation relies on a pure_callback with Numba just-in-time compilation. It supports arbitrarily shaped inputs without vmapping; the top-k is applied along the last dimension.

>>> values = (5 * jax.random.normal(jax.random.PRNGKey(2), (2, 3, 4))).astype(jnp.int32)
>>> print("values: \n", values)
values:
 [[[-11   1   4  -7]
  [  5   1   4   1]
  [-10  -3   4  -6]]

 [[  7  -6  -8  -4]
  [ -1   9   8   4]
  [  0   0  -5   5]]]
>>> print("soft top-k mask with PAV: \n", sparse_soft_topk.sparse_soft_topk_mask_pav(values, k=2, l=1e-2))
soft top-k mask with PAV:
 [[[0.         0.99999714 0.99999714 0.        ]
  [1.0000687  0.         0.99999714 0.        ]
  [0.         0.99999714 0.99999714 0.        ]]

 [[1.0000687  0.         0.         1.0000687 ]
  [0.         1.0000687  1.0000687  0.        ]
  [0.49999988 0.49999988 0.         1.0000687 ]]]
>>> print("soft top-k in magnitude with PAV: \n", sparse_soft_topk.sparse_soft_topk_mag_pav(values, k=2, l=1e-2))
soft top-k in magnitude with PAV:
 [[[-10.999428    0.          0.         -7.0002456]
  [  5.000036    0.          3.9999368   0.       ]
  [-10.000449   -0.          0.         -5.999853 ]]

 [[  7.0002456  -0.         -7.99995    -0.       ]
  [ -0.          8.999435    7.99995     0.       ]
  [  0.          0.         -5.000036    5.000036 ]]]
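Since pure_callback is compatible with tracing, the PAV-based operators also work under jax.jit (the differentiability example below relies on this). A minimal sketch, reusing values from above:

>>> pav_mask = jax.jit(functools.partial(
...     sparse_soft_topk.sparse_soft_topk_mask_pav, k=2, l=1e-2))
>>> out = pav_mask(values)  # same result as the direct call above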

Our Dykstra implementation is written in pure JAX, so one can simply batch it with vmap (here, two nested vmaps apply the operator along the last axis):

>>> soft_topk_mask_fn = functools.partial(
...     sparse_soft_topk.sparse_soft_topk_mask_dykstra,
...     k=2, num_iter=100, l=1e-2)
>>> print("soft top-k mask with Dykstra: \n", jax.vmap(jax.vmap(soft_topk_mask_fn), in_axes=1, out_axes=1)(values))
soft top-k mask with Dykstra:
 [[[0.         0.99999905 0.99999905 0.        ]
  [1.0000229  0.         0.99999905 0.        ]
  [0.         0.99999905 0.99999905 0.        ]]

 [[1.0000229  0.         0.         1.0000229 ]
  [0.         1.0000229  1.0000229  0.        ]
  [0.5        0.5        0.         1.0000229 ]]]
>>> soft_topk_mag_fn = functools.partial(
...     sparse_soft_topk.sparse_soft_topk_mag_dykstra,
...     k=2, num_iter=100, l=1e-2)
>>> print("soft top-k in magnitude with Dykstra: \n", jax.vmap(jax.vmap(soft_topk_mag_fn), in_axes=1, out_axes=1)(values))
soft top-k in magnitude with Dykstra:
 [[[-10.999967    0.          0.         -6.999992 ]
  [  4.9999804   0.          3.9999988   0.       ]
  [ -9.999961   -0.          0.         -5.999986 ]]

 [[  6.999992   -0.         -7.9999976  -0.       ]
  [ -0.          8.999955    7.9999976   0.       ]
  [  0.          0.         -4.9999804   4.9999804]]]

Differentiability

Our operators are differentiable. By default, the top-k with the PAV algorithm uses p-norm regularization with p=4/3, which leads to a continuously differentiable operator. Dykstra's projection algorithm works for p=2, which leads to an operator that is differentiable almost everywhere.

>>> k = 4
>>> l = 0.5
# We define two scalar functions: one built from our sparse soft operators, the other from their hard counterparts.
>>> @jax.jit
... def sparse_soft_f(x):
...   out = (sparse_soft_topk.sparse_soft_topk_mask_pav(x, k=k, l=l)
...          + sparse_soft_topk.sparse_soft_topk_mask_pav(x**2, k=k, l=l))
...   return out.sum()
>>> @jax.jit
... def hard_f(x):
...   return (sparse_soft_topk.hard_topk_mask(x, k=k)
...           + sparse_soft_topk.hard_topk_mask(x**2, k=k)).sum()
>>> x = jax.random.normal(jax.random.PRNGKey(2), (10,))
>>> x = jax.random.normal(jax.random.PRNGKey(2), (10,))

The hard function has zero derivatives everywhere:

>>> print(jax.grad(hard_f)(x))
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

whereas our soft operator has non-zero derivatives:

>>> print(jax.grad(sparse_soft_f)(x))
[ 0.0000000e+00 -2.2090350e-07 4.7683716e-07 0.0000000e+00
-3.6416495e-07 -3.7008726e-07 -1.1920929e-07 -1.7615172e-07
0.0000000e+00 -1.7881393e-07]
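Because these gradients are non-zero, one can train through the soft operators, e.g. by gradient descent on the input. A minimal sketch, reusing k, l and x from above (the quadratic target and step size are illustrative assumptions, not part of the library):

>>> target = jnp.zeros(10).at[:k].set(1.)  # hypothetical target mask
>>> def fit_loss(x):
...   mask = sparse_soft_topk.sparse_soft_topk_mask_pav(x, k=k, l=l)
...   return jnp.sum((mask - target) ** 2)
>>> for _ in range(100):
...   x = x - 0.1 * jax.grad(fit_loss)(x)
>>> print(fit_loss(x))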

Cite

If you use this code in your project, please cite:

Michael E. Sander, Joan Puigcerver, Josip Djolonga, Gabriel Peyré, Mathieu Blondel
Fast, Differentiable and Sparse Top-k: a Convex Analysis Perspective
arXiv preprint arXiv:2302.01425
https://arxiv.org/abs/2302.01425
