google-research

Форк
0
49 строк · 1.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""xm_utils"""
17

18
import jax.numpy as jnp
19
import numpy as np
20

21

22
def clip(x, clip_norm=1.0):
23
  divisor = jnp.maximum(jnp.linalg.norm(x) / clip_norm, 1.)
24
  return x / divisor
25

26

27
# To be used with jax.jit
28
def eval_step(wopt, test_x_np_list, hidden_dims, num_labels):
29
  nc = 0.0
30
  t = 0.0
31
  theta_np = jnp.reshape(wopt, (-1, hidden_dims))
32
  for l in range(num_labels):
33
    l_p = jnp.argmax(
34
        jnp.einsum('ld,nd->nl', theta_np, jnp.array(test_x_np_list[l])),
35
        axis=1)
36
    t += len(test_x_np_list[l])
37
    nc += jnp.sum(l_p == l)
38
  return nc / t
39

40

41
def to_flat_np(xs, labels, num_labels):
42
  xs_np = list(map(lambda x: x.numpy(), xs))
43
  labels = list(map(lambda x: x.numpy(), labels))
44
  x_np = np.concatenate(xs_np, axis=0)
45
  y_np = np.concatenate(labels, axis=0)
46
  x_list = [[] for _ in range(num_labels)]
47
  for x, y in zip(x_np, y_np):
48
    x_list[y].append(x)
49
  return x_list
50

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.