google-research
46 строк · 1.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Implementations of different initialization methods."""
17
18import numpy as np
19import tensorflow.compat.v1 as tf
20
21
def uniform(shape, scale=0.05, name=None):
  """Uniform init: a float32 variable sampled from [-scale, scale).

  Args:
    shape: Shape of the variable to create.
    scale: Half-width of the sampling interval; values are drawn uniformly
      from [-scale, scale).
    name: Optional name for the created variable.

  Returns:
    A `tf.Variable` of dtype float32 holding the sampled values.
  """
  low, high = -scale, scale
  values = tf.random_uniform(shape, minval=low, maxval=high, dtype=tf.float32)
  return tf.Variable(values, name=name)
27
28
def glorot(shape, name=None):
  """Glorot & Bengio (AISTATS 2010) init: uniform in [-limit, limit).

  The half-width is limit = sqrt(6 / (shape[0] + shape[1])). Only the first
  two entries of `shape` enter this computation. (NOTE(review): presumably
  `shape` is [fan_in, fan_out] of a dense weight matrix; confirm against
  callers.)

  Args:
    shape: Shape of the variable to create; must have at least two entries.
    name: Optional name for the created variable.

  Returns:
    A `tf.Variable` of dtype float32 holding the sampled values.
  """
  fan_sum = shape[0] + shape[1]
  limit = np.sqrt(6.0 / fan_sum)
  values = tf.random_uniform(
      shape, minval=-limit, maxval=limit, dtype=tf.float32)
  return tf.Variable(values, name=name)
35
36
def zeros(shape, name=None):
  """All-zeros init.

  Args:
    shape: Shape of the variable to create.
    name: Optional name for the created variable.

  Returns:
    A `tf.Variable` of dtype float32 whose every entry is 0.
  """
  return tf.Variable(tf.zeros(shape, dtype=tf.float32), name=name)
41
42
def ones(shape, name=None):
  """All-ones init.

  Args:
    shape: Shape of the variable to create.
    name: Optional name for the created variable.

  Returns:
    A `tf.Variable` of dtype float32 whose every entry is 1.
  """
  return tf.Variable(tf.ones(shape, dtype=tf.float32), name=name)
47