google-research

Форк
0
/
gradient_utils.py 
47 строк · 1.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utils for gradient alignment code."""
17

18
import jax
19
import jax.numpy as jnp
20

21

22
def tree_dot(tree_x, tree_y):
23
  a = jax.tree_util.tree_map(lambda x, y: jax.lax.dot(x.ravel(), y.ravel()),
24
                             tree_x, (tree_y))
25
  return jnp.sum(jnp.array(jax.tree_util.tree_flatten(a)[0]))
26

27

28
def tree_mult(tree_x, val_y):
29
  return jax.tree_util.tree_map(
30
      lambda x: x.ravel() * val_y, tree_x)
31

32

33
def tree_div(tree_x, val_y):
34
  return jax.tree_util.tree_map(
35
      lambda x: x.ravel() / val_y, tree_x)
36

37

38
def tree_diff(tree_x, tree_y):
39
  return jax.tree_util.tree_map(lambda x, y: x.ravel() - y.ravel(), tree_x,
40
                                (tree_y))
41

42

43
def tree_norm(tree_x):
44
  a = jax.tree_util.tree_map(
45
      lambda x: jnp.sum(jnp.square(x)), tree_x)
46
  b = jnp.sum(jnp.array(jax.tree_util.tree_flatten(a)[0]))
47
  return jnp.sqrt(b)
48

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.