google-research
49 строк · 1.5 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""xm_utils"""
import jax.numpy as jnp
import numpy as np
21
def clip(x, clip_norm=1.0):
  """Rescales `x` so that its overall L2 norm is at most `clip_norm`.

  The whole array is treated as a single vector (`jnp.linalg.norm` is called
  without an axis), so every element is divided by the same scalar.

  Args:
    x: Array to clip.
    clip_norm: Maximum allowed L2 norm of the result.

  Returns:
    `x` itself if its norm is already <= `clip_norm`; otherwise `x` divided by
    `norm(x) / clip_norm`, which brings its norm down to `clip_norm`.
  """
  norm_ratio = jnp.linalg.norm(x) / clip_norm
  # Never scale up: a ratio below 1 means `x` is already within the bound.
  return x / jnp.maximum(norm_ratio, 1.)
26
# To be used with jax.jit
def eval_step(wopt, test_x_np_list, hidden_dims, num_labels):
  """Returns the classification accuracy of a linear scorer on grouped data.

  `wopt` is reshaped into a `(-1, hidden_dims)` weight matrix `theta`
  (presumably one row per label -- TODO confirm against callers). For each
  label, every example in `test_x_np_list[label]` is scored against all rows of
  `theta`; the arg-max row is the prediction, and it counts as correct when it
  equals `label`.

  Args:
    wopt: Weight array whose element count is a multiple of `hidden_dims`.
    test_x_np_list: Sequence indexed by label; entry `label` holds that class's
      examples (each a length-`hidden_dims` vector). Empty entries are allowed.
    hidden_dims: Feature dimension of each example, i.e. the row length of
      `theta`.
    num_labels: Number of labels to evaluate; entries `0 .. num_labels - 1` of
      `test_x_np_list` are read.

  Returns:
    Fraction of correctly classified examples over all evaluated labels.

  Raises:
    ZeroDivisionError: If there are no examples at all (e.g. `num_labels == 0`
      or every evaluated class is empty).
  """
  num_correct = 0.0
  num_total = 0.0
  theta = jnp.reshape(wopt, (-1, hidden_dims))
  for label in range(num_labels):
    examples = test_x_np_list[label]
    # A class with no examples contributes nothing to either count. Skipping it
    # also avoids `jnp.array([])`, which is rank-1 and rejected by the 'nd'
    # einsum operand. The test only inspects Python lengths / static shapes.
    if len(examples) == 0:
      continue
    # scores[n, l] = <theta[l], x[n]>; the predicted label is the best row.
    scores = jnp.einsum('ld,nd->nl', theta, jnp.array(examples))
    predictions = jnp.argmax(scores, axis=1)
    num_total += len(examples)
    num_correct += jnp.sum(predictions == label)
  return num_correct / num_total
40
def to_flat_np(xs, labels, num_labels):
  """Groups examples from batched tensors into per-label Python lists.

  Args:
    xs: Iterable of example batches; each element must provide `.numpy()`
      returning an array whose first axis indexes examples.
    labels: Iterable of label batches aligned with `xs`; each `.numpy()` must
      return integer labels, one per example along the first axis.
    num_labels: Number of buckets to create.

  Returns:
    A list of `num_labels` lists. Entry `k` holds, in input order, every row
    of the concatenated `xs` whose label is `k`. If both `xs` and `labels` are
    empty, all buckets are empty.

  Raises:
    ValueError: If the total numbers of examples and labels differ.
    IndexError: If any label lies outside `[0, num_labels)`.
  """
  xs_np = [x.numpy() for x in xs]
  labels_np = [y.numpy() for y in labels]
  x_list = [[] for _ in range(num_labels)]
  if not xs_np and not labels_np:
    # Nothing to group; np.concatenate would reject an empty sequence.
    return x_list
  x_np = np.concatenate(xs_np, axis=0)
  y_np = np.concatenate(labels_np, axis=0)
  # zip() below would silently drop the unmatched tail and misalign the data.
  if len(x_np) != len(y_np):
    raise ValueError(
        f'got {len(x_np)} examples but {len(y_np)} labels')
  # Check the range up front: a negative label would otherwise be accepted by
  # Python's negative list indexing and land in the wrong bucket.
  out_of_range = (y_np < 0) | (y_np >= num_labels)
  if np.any(out_of_range):
    bad_label = y_np[out_of_range][0]
    raise IndexError(
        f'label {bad_label} is outside the range [0, {num_labels})')
  for x, y in zip(x_np, y_np):
    x_list[y].append(x)
  return x_list