# pytorch / random.py
import contextlib
from typing import Generator
import warnings

from torch._C import default_generator
import torch


def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note:: This function only works for CPU. For CUDA, please use
              torch.manual_seed(seed), which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)


def get_rng_state() -> torch.Tensor:
    r"""Returns the random number generator state as a `torch.ByteTensor`."""
    return default_generator.get_state()


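# Usage sketch for get_rng_state/set_rng_state above (an editorial addition, not
# part of the upstream module): capture the CPU generator state, draw random
# numbers, then restore the state so the same numbers are produced again.
#
#     state = torch.get_rng_state()
#     a = torch.rand(3)
#     torch.set_rng_state(state)
#     b = torch.rand(3)   # equal to `a`, since the state was restored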
def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers. Returns a
    `torch.Generator` object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
            is raised. Negative inputs are remapped to positive values with the formula
            `0xffff_ffff_ffff_ffff + seed`.
    """
    seed = int(seed)
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps
    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu
    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return default_generator.manual_seed(seed)


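# Usage sketch for manual_seed above (an editorial addition, not part of the
# upstream module): re-seeding with the same value replays the same stream of
# random numbers, and a negative seed is remapped with the documented formula
# `0xffff_ffff_ffff_ffff + seed`, so manual_seed(-1) is equivalent to
# manual_seed(0xffff_ffff_ffff_fffe).
#
#     torch.manual_seed(2147483647)
#     a = torch.randn(2, 2)
#     torch.manual_seed(2147483647)
#     b = torch.randn(2, 2)   # identical to `a`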
def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number. Returns a 64 bit number used to seed the RNG.
    """
    seed = default_generator.seed()
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps
    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu
    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return seed


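# Usage sketch for seed() above (an editorial addition, not part of the upstream
# module): seed() picks a non-deterministic seed, applies it to all devices, and
# returns it so the run can be replayed later with manual_seed.
#
#     s = torch.seed()        # record/log this value for reproducibility
#     a = torch.rand(3)
#     torch.manual_seed(s)    # replay with the recorded seed
#     b = torch.rand(3)       # equal to `a`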
def _seed_custom_device(seed) -> None:
    r"""Sets the seed to generate random numbers for the custom device.

    Args:
        seed (int): The desired seed.

    See [Note: support the custom device with privateuse1]
    """
    seed = int(seed)
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    if hasattr(torch, custom_backend_name):
        custom_device_mod = getattr(torch, custom_backend_name)
        _bad_fork_name = "_is_in_bad_fork"
        _seed_all_name = "manual_seed_all"
        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
            if not getattr(custom_device_mod, _bad_fork_name)():
                getattr(custom_device_mod, _seed_all_name)(seed)
        else:
            message = f"Setting the seed for the `{custom_backend_name}` device has no effect; please add the "
            message += f"`{_bad_fork_name}` and `{_seed_all_name}` APIs to the `{custom_backend_name}` device module."
            warnings.warn(message, UserWarning, stacklevel=3)


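# Sketch for _seed_custom_device above (an editorial addition; the module and
# its name are hypothetical): a privateuse1 backend registered via
# torch._register_device_module only receives the seed if its device module
# exposes both of the attributes probed above; otherwise the warning fires.
#
#     class _my_backend_module:                  # stand-in for a real device module
#         @staticmethod
#         def _is_in_bad_fork() -> bool:         # True only inside a forked worker
#             return False
#
#         @staticmethod
#         def manual_seed_all(seed: int) -> None:
#             pass                               # forward `seed` to every device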
def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python `int`.
    """
    return default_generator.initial_seed()


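# Usage sketch for initial_seed above (an editorial addition, not part of the
# upstream module): it reports the seed the default CPU generator was last
# seeded with.
#
#     torch.manual_seed(42)
#     torch.initial_seed()    # 42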
_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked.  This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type, default is `cuda`. For a custom device,
            see details in [Note: support the custom device with privateuse1]
    """

    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(f"torch has no module named `{device_type}`; you should register " +
                           "a module via `torch._register_device_module`.")
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
                       f"you have used {_caller} without explicitly specifying which devices are being used. "
                       f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
                       f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
                       f" making use of a few {device_type.upper()} devices, set the environment variable "
                       f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                       "with the set of devices you are actually using. For example, if you are using CPU only, "
                       f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                       f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0].  To initialize all devices "
                       f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                       f"`range(torch.{device_type}.device_count())`.")
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)


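# Usage sketch for fork_rng above (an editorial addition, not part of the
# upstream module): random draws made inside the `with` block do not disturb
# the RNG stream outside of it.
#
#     with torch.random.fork_rng(devices=[]):   # fork only the CPU RNG state
#         torch.manual_seed(0)
#         noise = torch.randn(4)
#     # on exit the previous CPU (and device) RNG state is restored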