google-research

Форк
0
76 строк · 2.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utilities for logging, neural network activations, and initializations."""
17

18
# pylint: disable=g-bad-import-order, unused-import, g-multiple-import
19
# pylint: disable=line-too-long, missing-docstring, g-importing-member
20
import tensorflow.compat.v1 as tf
21
import numpy as np
22
import gin
23
from functools import partial
24
import functools
25

26
from weak_disentangle import tensorsketch as ts
27

28

29
@gin.configurable
30
def log(*args, debug=False):
31
  if debug:
32
    print(*args)
33
  else:
34
    tf.logging.info(" ".join(map(str, args)))
35

36

37
def reset_parameters(m):
38
  m.reset_parameters()
39

40

41
# pylint: disable=invalid-name
42
def add_act(m, Act):
43
  m.act = Act()
44
  m.out_hooks.update(dict(act=lambda self, x: self.act(x)))
45

46

47
def remove_act(m):
48
  del m.act
49
  del m.out_hooks["act"]
50

51

52
# pylint: disable=unused-argument
53
@gin.configurable
54
def initializer(kernel, bias, method, layer):
55
  if method == "pytorch":
56
    pytorch_init(kernel, bias)
57
  elif method == "keras":
58
    keras_init(kernel, bias)
59

60

61
def pytorch_init(kernel, bias):
62
  fan_in, _ = ts.utils.compute_fan(kernel)
63
  limit = np.sqrt(1 / fan_in)
64
  kernel.assign(tf.random.uniform(kernel.shape, -limit, limit))
65

66
  if bias is not None:
67
    bias.assign(tf.random.uniform(bias.shape, -limit, limit))
68

69

70
def keras_init(kernel, bias):
71
  fan_in, fan_out = ts.utils.compute_fan(kernel)
72
  limit = np.sqrt(6 / (fan_in + fan_out))
73
  kernel.assign(tf.random.uniform(kernel.shape, -limit, limit))
74

75
  if bias is not None:
76
    bias.assign(tf.zeros(bias.shape))
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.