pytorch / random.py
179 lines · 5.1 KB
from typing import Iterable, List, Union

import torch
from .. import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count

__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]


def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)
    idx = device.index
    if idx is None:
        idx = current_device()
    default_generator = torch.cuda.default_generators[idx]
    return default_generator.get_state()


def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices."""
    results = []
    for i in range(device_count()):
        results.append(get_rng_state(i))
    return results
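

# --- Illustrative usage (not part of the upstream file) ----------------------
# A minimal sketch of capturing RNG state with the two getters above. The helper
# name ``_example_capture_rng_state`` is hypothetical and added purely for
# illustration; these functions are also re-exported as ``torch.cuda.get_rng_state``
# and ``torch.cuda.get_rng_state_all``.
def _example_capture_rng_state():
    if not torch.cuda.is_available():  # skip on CPU-only builds
        return None, None
    current = get_rng_state()         # ByteTensor for the current CUDA device
    per_device = get_rng_state_all()  # one ByteTensor per visible device
    return current, per_device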


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device or int, optional): The device to set the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb():
        idx = device.index
        if idx is None:
            idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)
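

# --- Illustrative usage (not part of the upstream file) ----------------------
# A minimal sketch of a save/restore round trip: snapshot the RNG state, draw some
# random numbers, rewind to the snapshot, and draw again to reproduce them. The
# helper name ``_example_rng_round_trip`` is hypothetical and exists only for
# illustration.
def _example_rng_round_trip():
    if not torch.cuda.is_available():  # skip on CPU-only builds
        return None
    snapshot = get_rng_state()
    first = torch.randn(4, device="cuda")
    set_rng_state(snapshot)            # rewind the current device's generator
    second = torch.randn(4, device="cuda")
    return torch.equal(first, second)  # True: both draws used the same state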


def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.manual_seed(seed)

    _lazy_call(cb, seed=True)


def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(device_count()):
            default_generator = torch.cuda.default_generators[i]
            default_generator.manual_seed(seed)

    _lazy_call(cb, seed_all=True)
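

# --- Illustrative usage (not part of the upstream file) ----------------------
# A minimal sketch of seeding for reproducibility. Seeding the current device alone
# is not enough for multi-GPU runs, so the all-device variant is shown as well. The
# helper name ``_example_seed_for_reproducibility`` is hypothetical; the seed value
# 42 is arbitrary.
def _example_seed_for_reproducibility():
    manual_seed(42)      # seeds only the current CUDA device (no-op without CUDA)
    manual_seed_all(42)  # seeds every visible CUDA device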


def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
    """

    def cb():
        idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def cb():
        random_seed = 0
        seeded = False
        for i in range(device_count()):
            default_generator = torch.cuda.default_generators[i]
            if not seeded:
                default_generator.seed()
                random_seed = default_generator.initial_seed()
                seeded = True
            else:
                default_generator.manual_seed(random_seed)

    _lazy_call(cb)
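

# --- Illustrative usage (not part of the upstream file) ----------------------
# A minimal sketch of non-deterministic seeding. As the loop above shows, ``seed_all``
# draws one random seed on the first device and copies it to the others, so all GPUs
# still agree with each other even though the run is not reproducible. The helper
# name ``_example_random_seeding`` is hypothetical.
def _example_random_seeding():
    seed()      # reseed the current device from a non-deterministic source
    seed_all()  # reseed device 0 randomly, then propagate that seed to the rest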


def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    idx = current_device()
    default_generator = torch.cuda.default_generators[idx]
    return default_generator.initial_seed()
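

# --- Illustrative usage (not part of the upstream file) ----------------------
# A minimal sketch of reading back the seed the current device's generator was last
# seeded with, e.g. to log it for a later reproduction of a run. The helper name
# ``_example_log_seed`` is hypothetical; note that ``initial_seed`` (like
# ``get_rng_state``) eagerly initializes CUDA.
def _example_log_seed():
    if not torch.cuda.is_available():  # skip on CPU-only builds
        return None
    manual_seed(1234)      # arbitrary seed, used only so the read-back is predictable
    return initial_seed()  # 1234 for the current device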