google-research
47 строк · 1.4 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for gradient alignment code."""
17
18import jax19import jax.numpy as jnp20
21
def tree_dot(tree_x, tree_y):
  """Returns the inner product of two pytrees, treating each as one flat vector.

  Corresponding leaves are flattened to 1-D and dotted with `jax.lax.dot`;
  the per-leaf results are then summed into a single scalar.

  Args:
    tree_x: pytree of array-like leaves (arrays or Python scalars).
    tree_y: pytree whose structure matches `tree_x`; each leaf must have the
      same number of elements as its counterpart in `tree_x`.

  Returns:
    Scalar sum over leaves of <ravel(x_leaf), ravel(y_leaf)>. A tree with no
    leaves yields 0.0.
  """
  # jnp.ravel (instead of the `.ravel()` method) also accepts Python scalar
  # leaves, so this works on the same leaf types as tree_norm (jnp.square).
  per_leaf = jax.tree_util.tree_map(
      lambda x, y: jax.lax.dot(jnp.ravel(x), jnp.ravel(y)), tree_x, tree_y)
  return jnp.sum(jnp.array(jax.tree_util.tree_leaves(per_leaf)))
27
def tree_mult(tree_x, val_y):
  """Scales every leaf of a pytree by `val_y`, flattening each leaf to 1-D.

  Args:
    tree_x: pytree whose leaves provide a `.ravel()` method (e.g. arrays).
    val_y: multiplier applied to each flattened leaf.

  Returns:
    Pytree with the structure of `tree_x` where each leaf is
    `leaf.ravel() * val_y` (note: original leaf shapes are NOT preserved).
  """
  def _scale(leaf):
    flat = leaf.ravel()
    return flat * val_y

  return jax.tree_util.tree_map(_scale, tree_x)
32
def tree_div(tree_x, val_y):
  """Divides every leaf of a pytree by `val_y`, flattening each leaf to 1-D.

  Args:
    tree_x: pytree whose leaves provide a `.ravel()` method (e.g. arrays).
    val_y: divisor applied to each flattened leaf.

  Returns:
    Pytree with the structure of `tree_x` where each leaf is
    `leaf.ravel() / val_y` (note: original leaf shapes are NOT preserved).
  """
  def _divide_leaf(leaf):
    flat = leaf.ravel()
    return flat / val_y

  return jax.tree_util.tree_map(_divide_leaf, tree_x)
37
def tree_diff(tree_x, tree_y):
  """Subtracts `tree_y` from `tree_x` leaf-wise, flattening each leaf to 1-D.

  Args:
    tree_x: pytree whose leaves provide a `.ravel()` method (e.g. arrays).
    tree_y: pytree whose structure matches `tree_x`, with leaves
      broadcast-compatible after flattening.

  Returns:
    Pytree with the structure of `tree_x` where each leaf is
    `x_leaf.ravel() - y_leaf.ravel()` (original leaf shapes are NOT kept).
  """
  def _flat_sub(left, right):
    return left.ravel() - right.ravel()

  return jax.tree_util.tree_map(_flat_sub, tree_x, tree_y)
42
def tree_norm(tree_x):
  """Returns the global L2 norm of a pytree, treating all leaves as one vector.

  Args:
    tree_x: pytree of array-like leaves.

  Returns:
    Scalar sqrt(sum over leaves of sum(leaf ** 2)). A tree with no leaves
    yields 0.0.
  """
  leaf_sq_sums = [
      jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(tree_x)
  ]
  total = jnp.sum(jnp.array(leaf_sq_sums))
  return jnp.sqrt(total)