# google-research
# 47 lines · 1.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example of fishy-based sampling."""
# A simple example of how to handle the sampled Fisher for the final
# classification layer, from: https://openreview.net/forum?id=cScb-RrBQC
#
import math

import jax
import jax.numpy as jnp


@jax.custom_vjp
def sampled_with_softmax(unused_rng, pre_act):
  """Returns the logits unchanged; only the gradient is customised.

  The softmax and the class sampling happen in the custom VJP rules
  (`sampled_with_softmax_fwd` / `sampled_with_softmax_bwd`), not here.

  Args:
    unused_rng: PRNG key. Not read by the primal computation; it is passed
      through to the custom VJP rules.
    pre_act: Logits (pre-softmax activations).

  Returns:
    `pre_act`, unchanged.
  """
  logits = pre_act
  return logits


def sampled_with_softmax_fwd(rng, pre_act):
  """Forward rule for `sampled_with_softmax`.

  Args:
    rng: PRNG key, saved so the backward rule can draw a class sample.
    pre_act: Logits (pre-softmax activations).

  Returns:
    A `(primal_out, residuals)` pair. `primal_out` is `pre_act` itself, and
    `residuals` is `(softmax(pre_act), pre_act, rng)`.
  """
  probs = jax.nn.softmax(pre_act)
  residuals = (probs, pre_act, rng)
  return pre_act, residuals


def sampled_with_softmax_bwd(res, unused_g):
  """Backward rule for `sampled_with_softmax`: a sampled-label gradient.

  Draws one class per example from Categorical(softmax(pre_act)). It then
  returns the softmax cross-entropy gradient w.r.t. the logits for that
  sampled label, `softmax(pre_act) - one_hot(sample)`, averaged over all
  leading (non-class) dimensions. The incoming cotangent is intentionally
  ignored.

  Args:
    res: Residuals `(post_act, pre_act, rng)` saved by the forward rule.
    unused_g: Incoming cotangent; ignored.

  Returns:
    `(None, grad)`. `None` is the (zero) cotangent for the PRNG key, and
    `grad` has the same shape and dtype as `post_act`.
  """
  post_act, pre_act, rng = res
  num_classes = post_act.shape[-1]
  # One sampled class index per example, drawn along the last axis.
  sample = jax.random.categorical(rng, logits=pre_act)
  # Build the one-hot in the residual's dtype. A hard-coded float32 would
  # upcast bf16/f16 logits, and custom_vjp expects each cotangent to match
  # its primal input's dtype.
  target = jax.nn.one_hot(sample, num_classes=num_classes, dtype=post_act.dtype)
  # Average over every leading dimension (batch, sequence, ...).
  num_examples = math.prod(post_act.shape[:-1])
  return (None, (post_act - target) / num_examples)


# Attach the custom forward/backward rules. Differentiating
# `sampled_with_softmax` then yields the sampled softmax gradient computed in
# `sampled_with_softmax_bwd`.
sampled_with_softmax.defvjp(
    fwd=sampled_with_softmax_fwd,
    bwd=sampled_with_softmax_bwd,
)