"""This file contains utilities for initializing neural network parameters."""
import math
import warnings

import torch
from torch import Tensor

from typing import Optional as _Optional

# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b, generator=None):
    with torch.no_grad():
        return tensor.uniform_(a, b, generator=generator)

def _no_grad_normal_(tensor, mean, std, generator=None):
    with torch.no_grad():
        return tensor.normal_(mean, std, generator=generator)

def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

def _no_grad_fill_(tensor, val):
    with torch.no_grad():
        return tensor.fill_(val)

def _no_grad_zero_(tensor):
    with torch.no_grad():
        return tensor.zero_()

def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError(f"negative_slope {param} not a valid number")
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    else:
        raise ValueError(f"Unsupported nonlinearity {nonlinearity}")

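# Worked example (an illustrative sketch, not part of the upstream module): for
# leaky_relu the branch above computes sqrt(2 / (1 + negative_slope^2)), so a
# slope of 0.2 yields a gain of about 1.3867. `_demo_calculate_gain` is a
# hypothetical helper added here purely for illustration.
def _demo_calculate_gain():
    gain = calculate_gain('leaky_relu', 0.2)
    assert abs(gain - math.sqrt(2.0 / (1 + 0.2 ** 2))) < 1e-12
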
def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

    :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
        )
    return _no_grad_uniform_(tensor, a, b, generator)

def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
        )
    return _no_grad_normal_(tensor, mean, std, generator)

def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)

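# Quick sanity sketch (illustrative only, not from the original file): the
# inverse-CDF construction in `_no_grad_trunc_normal_` ends with a clamp, so
# every sample must land inside [a, b]. `_demo_trunc_normal_` is a hypothetical
# helper added for this walkthrough.
def _demo_trunc_normal_():
    w = torch.empty(10000)
    trunc_normal_(w, mean=0.0, std=1.0, a=-2.0, b=2.0)
    assert w.min().item() >= -2.0 and w.max().item() <= 2.0
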
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
    return _no_grad_fill_(tensor, val)

def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    return _no_grad_fill_(tensor, 1.)

def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    return _no_grad_zero_(tensor)

def eye_(tensor):
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
    return tensor

def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups > 1, each group of channels preserves identity.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()

    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])

    with torch.no_grad():
        tensor.zero_()

        for g in range(groups):
            for d in range(min_dim):
                if dimensions == 3:  # Temporal convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
                elif dimensions == 4:  # Spatial convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2] = 1
                else:  # Volumetric convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2, tensor.size(4) // 2] = 1
    return tensor

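# Illustrative check (a sketch, not part of the upstream module): a Dirac-
# initialized conv weight acts as the identity on matching channels. The
# deferred `torch.nn.functional` import avoids a circular import at module
# load; `_demo_dirac_` is a hypothetical helper for illustration only.
def _demo_dirac_():
    import torch.nn.functional as F
    w = torch.empty(8, 8, 3)  # out_channels=8, in_channels=8, kernel_size=3
    dirac_(w)
    x = torch.randn(1, 8, 16)
    # With the single tap at the kernel center, padding=1 makes conv1d an identity.
    assert torch.allclose(F.conv1d(x, w, padding=1), x)
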
def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out

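# Worked example (illustrative, not from the original file): for a conv weight
# of shape (out=3, in=16, 5, 5) the receptive field size is 5 * 5 = 25, so
# fan_in is 16 * 25 = 400 and fan_out is 3 * 25 = 75. `_demo_fan` is a
# hypothetical helper added for illustration.
def _demo_fan():
    w = torch.empty(3, 16, 5, 5)
    assert _calculate_fan_in_and_fan_out(w) == (16 * 25, 3 * 25)
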
def xavier_uniform_(
    tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a, generator)

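# Sanity sketch (illustrative only, not from the original file): for a (3, 5)
# weight, fan_in=5 and fan_out=3, so with gain=1 the uniform bound is
# sqrt(6 / 8) ~= 0.866 and no sample can exceed it. `_demo_xavier_uniform_` is
# a hypothetical helper.
def _demo_xavier_uniform_():
    w = torch.empty(3, 5)
    xavier_uniform_(w)
    assert w.abs().max().item() <= math.sqrt(6.0 / (5 + 3))
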
def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

    return _no_grad_normal_(tensor, 0., std, generator)

def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out

def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity,
            generator=generator,
        )

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound, generator=generator)

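# Worked example (an illustrative sketch, not part of the upstream module):
# with nonlinearity='relu' the gain is sqrt(2), so for fan_in=5 the uniform
# bound is sqrt(3) * sqrt(2) / sqrt(5) = sqrt(6/5). `_demo_kaiming_uniform_`
# is a hypothetical helper added for illustration.
def _demo_kaiming_uniform_():
    w = torch.empty(3, 5)
    kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    assert w.abs().max().item() <= math.sqrt(6.0 / 5.0)
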
def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std, generator=generator)

def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new(rows, cols).normal_(0, 1, generator=generator)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor

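# Sanity sketch (illustrative only, not from the original file): when
# rows >= cols, the initialized matrix has orthonormal columns, so w^T w is
# (close to) the identity. `_demo_orthogonal_` is a hypothetical helper.
def _demo_orthogonal_():
    w = torch.empty(5, 3)
    orthogonal_(w)
    assert torch.allclose(w.t() @ w, torch.eye(3), atol=1e-6)
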
def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input `Tensor` as a sparse matrix.

    The non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor

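# Worked example (illustrative, not from the original file): with 10 rows and
# sparsity=0.25, ceil(0.25 * 10) = 3 entries are zeroed in every column; the
# remaining entries are almost surely nonzero. `_demo_sparse_` is a
# hypothetical helper added for illustration.
def _demo_sparse_():
    w = torch.empty(10, 4)
    sparse_(w, sparsity=0.25)
    assert all(int((w[:, j] == 0).sum()) >= 3 for j in range(4))
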
# for backward compatibility
def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn(f"nn.init.{old_name} is now deprecated in favor of nn.init.{new_name}.", stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = fr"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details."""
    deprecated_init.__name__ = old_name
    return deprecated_init

uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)