pytorch

Форк
0
/
__future__.py 
75 строк · 3.1 Кб
1
_overwrite_module_params_on_conversion: bool = False
2
_swap_module_params_on_conversion: bool = False
3

4

5
def set_overwrite_module_params_on_conversion(value: bool) -> None:
6
    """
7
    Sets whether to assign new tensors to the parameters instead of changing the
8
    existing parameters in-place when converting an ``nn.Module``.
9

10
    When enabled, the following methods will assign new parameters to the module:
11

12
    #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
13
    #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
14
    #. :meth:`nn.Module.to`
15
    #. :meth:`nn.Module.to_empty`
16

17
    Args:
18
        value (bool): Whether to assign new tensors or not.
19

20
    """
21
    global _overwrite_module_params_on_conversion
22
    _overwrite_module_params_on_conversion = value
23

24

25
def get_overwrite_module_params_on_conversion() -> bool:
26
    """
27
    Returns whether to assign new tensors to the parameters instead of changing the
28
    existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``.
29

30
    See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
31
    """
32
    return _overwrite_module_params_on_conversion
33

34

35
def set_swap_module_params_on_conversion(value: bool) -> None:
36
    """
37
    Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
38
    change the existing parameters in-place when converting an ``nn.Module`` and instead
39
    of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
40

41
    .. note::
42
        This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion`
43

44
    When enabled, the following methods will swap the existing parameters in-place:
45

46
    #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
47
    #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
48
    #. :meth:`nn.Module.to`
49
    #. :meth:`nn.Module.to_empty`
50
    #. :meth:`nn.Module.load_state_dict`
51

52
    The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:
53

54
    #. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
55
       :meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
56
    #. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
57
    #. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
58
       with ``res``
59

60
    Args:
61
        value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not.
62

63
    """
64
    global _swap_module_params_on_conversion
65
    _swap_module_params_on_conversion = value
66

67

68
def get_swap_module_params_on_conversion() -> bool:
69
    """
70
    Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting .data to
71
    change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
72

73
    See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
74
    """
75
    return _swap_module_params_on_conversion
76

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.