google-research
76 строк · 2.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for logging, neural network activations, and initializations."""
17
18# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
19# pylint: disable=line-too-long, missing-docstring, g-importing-member
20import tensorflow.compat.v1 as tf21import numpy as np22import gin23from functools import partial24import functools25
26from weak_disentangle import tensorsketch as ts27
28
@gin.configurable
def log(*args, debug=False):
  """Emit `args` either to stdout or to the TF logger.

  Args:
    *args: Objects to report.
    debug: When True, the objects are passed straight to `print`. Otherwise
      each object is converted with `str`, the results are joined with single
      spaces, and the resulting line is sent to `tf.logging.info`.
  """
  if not debug:
    tf.logging.info(" ".join(str(arg) for arg in args))
    return
  print(*args)
36
def reset_parameters(m):
  """Invoke `m.reset_parameters()` and discard its result (returns None)."""
  reset_fn = m.reset_parameters
  reset_fn()
40
41# pylint: disable=invalid-name
def add_act(m, Act):
  """Attach a freshly constructed activation to `m` and register it as a hook.

  Args:
    m: Object with a mutable `out_hooks` mapping; receives an `act` attribute.
    Act: Zero-argument callable (e.g. a class) producing the activation.
  """
  m.act = Act()

  # Resolved through `self.act` at call time, so replacing `m.act` later
  # changes what the hook applies.
  def _apply_act(self, x):
    return self.act(x)

  m.out_hooks.update({"act": _apply_act})
46
def remove_act(m):
  """Undo `add_act`: drop `m.act`, then drop the "act" entry of `m.out_hooks`.

  Raises AttributeError if `m` has no `act` (hooks are then left untouched),
  or KeyError if `m.out_hooks` has no "act" entry.
  """
  delattr(m, "act")
  m.out_hooks.pop("act")
51
52# pylint: disable=unused-argument
# pylint: disable=unused-argument
@gin.configurable
def initializer(kernel, bias, method, layer):
  """Initialize `kernel` (and `bias`) in place with the scheme named by `method`.

  Args:
    kernel: Weight variable to initialize.
    bias: Bias variable to initialize, or None.
    method: "pytorch" dispatches to `pytorch_init`, "keras" to `keras_init`.
      Any other value is a silent no-op: neither tensor is touched.
    layer: Not used by this function; presumably kept so the signature
      matches what callers pass — confirm against call sites.
  """
  # Checked in order with `==`, stopping at the first match.
  schemes = (("pytorch", pytorch_init), ("keras", keras_init))
  for scheme_name, init_fn in schemes:
    if method == scheme_name:
      init_fn(kernel, bias)
      break
60
def pytorch_init(kernel, bias):
  """PyTorch-style uniform initialization of `kernel` and `bias`, in place.

  Both tensors are filled from U(-b, b) with b = sqrt(1 / fan_in), where
  fan_in is the first value returned by `ts.utils.compute_fan(kernel)`.

  Args:
    kernel: Variable updated via `assign`.
    bias: Variable updated via `assign`, or None to leave the bias alone.
  """
  fan_in, _fan_out = ts.utils.compute_fan(kernel)
  bound = np.sqrt(1 / fan_in)

  def _sample(shape):
    return tf.random.uniform(shape, -bound, bound)

  # Kernel is sampled before bias so the RNG draw order stays fixed.
  kernel.assign(_sample(kernel.shape))
  if bias is None:
    return
  bias.assign(_sample(bias.shape))
69
def keras_init(kernel, bias):
  """Keras-style (Glorot-uniform) kernel initialization with a zero bias.

  The kernel is filled from U(-b, b) with b = sqrt(6 / (fan_in + fan_out)),
  using the fans returned by `ts.utils.compute_fan(kernel)`. The bias, when
  present, is set to zeros.

  Args:
    kernel: Variable updated via `assign`.
    bias: Variable updated via `assign`, or None to leave the bias alone.
  """
  fan_in, fan_out = ts.utils.compute_fan(kernel)
  total_fan = fan_in + fan_out
  bound = np.sqrt(6 / total_fan)
  kernel.assign(tf.random.uniform(kernel.shape, minval=-bound, maxval=bound))

  if bias is None:
    return
  bias.assign(tf.zeros(bias.shape))