pytorch

Форк
0
/
init.py 
53 строки · 2.2 Кб
1
import inspect
2
import torch
3

4

5
def skip_init(module_cls, *args, **kwargs):
6
    r"""
7
    Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.
8

9
    This can be useful if initialization is slow or if custom initialization will
10
    be performed, making the default initialization unnecessary. There are some caveats to this, due to
11
    the way this function is implemented:
12

13
    1. The module must accept a `device` arg in its constructor that is passed to any parameters
14
    or buffers created during construction.
15

16
    2. The module must not perform any computation on parameters in its constructor except
17
    initialization (i.e. functions from :mod:`torch.nn.init`).
18

19
    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
20
    uninitialized, as if having been created using :func:`torch.empty`.
21

22
    Args:
23
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
24
        args: args to pass to the module's constructor
25
        kwargs: kwargs to pass to the module's constructor
26

27
    Returns:
28
        Instantiated module with uninitialized parameters / buffers
29

30
    Example::
31

32
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
33
        >>> import torch
34
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
35
        >>> m.weight
36
        Parameter containing:
37
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
38
               requires_grad=True)
39
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
40
        >>> m2.weight
41
        Parameter containing:
42
        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
43
                  4.5915e-41]], requires_grad=True)
44

45
    """
46
    if not issubclass(module_cls, torch.nn.Module):
47
        raise RuntimeError(f'Expected a Module; got {module_cls}')
48
    if 'device' not in inspect.signature(module_cls).parameters:
49
        raise RuntimeError('Module must support a \'device\' arg to skip initialization')
50

51
    final_device = kwargs.pop('device', 'cpu')
52
    kwargs['device'] = 'meta'
53
    return module_cls(*args, **kwargs).to_empty(device=final_device)
54

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.