pytorch

Форк
0
/
random.py 
191 строка · 4.5 Кб
1
# mypy: ignore-errors
2

3
"""Wrapper to mimic (parts of) np.random API surface.
4

5
NumPy has strict guarantees on reproducibility etc; here we don't give any.
6

7
Q: default dtype is float64 in numpy
8

9
"""
10
from __future__ import annotations
11

12
import functools
13
from math import sqrt
14
from typing import Optional
15

16
import torch
17

18
from . import _dtypes_impl, _util
19
from ._normalizations import array_or_scalar, ArrayLike, normalizer
20

21

22
__all__ = [
23
    "seed",
24
    "random_sample",
25
    "sample",
26
    "random",
27
    "rand",
28
    "randn",
29
    "normal",
30
    "choice",
31
    "randint",
32
    "shuffle",
33
    "uniform",
34
]
35

36

37
def use_numpy_random():
38
    # local import to avoid ref cycles
39
    import torch._dynamo.config as config
40

41
    return config.use_numpy_random_stream
42

43

44
def deco_stream(func):
45
    @functools.wraps(func)
46
    def inner(*args, **kwds):
47
        if not use_numpy_random():
48
            return func(*args, **kwds)
49
        else:
50
            import numpy
51

52
            from ._ndarray import ndarray
53

54
            f = getattr(numpy.random, func.__name__)
55

56
            # numpy funcs accept numpy ndarrays, unwrap
57
            args = tuple(
58
                arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args
59
            )
60
            kwds = {
61
                key: val.tensor.numpy() if isinstance(val, ndarray) else val
62
                for key, val in kwds.items()
63
            }
64

65
            value = f(*args, **kwds)
66

67
            # `value` can be either numpy.ndarray or python scalar (or None)
68
            if isinstance(value, numpy.ndarray):
69
                value = ndarray(torch.as_tensor(value))
70

71
            return value
72

73
    return inner
74

75

76
@deco_stream
77
def seed(seed=None):
78
    if seed is not None:
79
        torch.random.manual_seed(seed)
80

81

82
@deco_stream
83
def random_sample(size=None):
84
    if size is None:
85
        size = ()
86
    dtype = _dtypes_impl.default_dtypes().float_dtype
87
    values = torch.empty(size, dtype=dtype).uniform_()
88
    return array_or_scalar(values, return_scalar=size == ())
89

90

91
def rand(*size):
92
    if size == ():
93
        size = None
94
    return random_sample(size)
95

96

97
sample = random_sample
98
random = random_sample
99

100

101
@deco_stream
102
def uniform(low=0.0, high=1.0, size=None):
103
    if size is None:
104
        size = ()
105
    dtype = _dtypes_impl.default_dtypes().float_dtype
106
    values = torch.empty(size, dtype=dtype).uniform_(low, high)
107
    return array_or_scalar(values, return_scalar=size == ())
108

109

110
@deco_stream
111
def randn(*size):
112
    dtype = _dtypes_impl.default_dtypes().float_dtype
113
    values = torch.randn(size, dtype=dtype)
114
    return array_or_scalar(values, return_scalar=size == ())
115

116

117
@deco_stream
118
def normal(loc=0.0, scale=1.0, size=None):
119
    if size is None:
120
        size = ()
121
    dtype = _dtypes_impl.default_dtypes().float_dtype
122
    values = torch.empty(size, dtype=dtype).normal_(loc, scale)
123
    return array_or_scalar(values, return_scalar=size == ())
124

125

126
@deco_stream
127
def shuffle(x):
128
    # no @normalizer because we do not cast e.g. lists to tensors
129
    from ._ndarray import ndarray
130

131
    if isinstance(x, torch.Tensor):
132
        tensor = x
133
    elif isinstance(x, ndarray):
134
        tensor = x.tensor
135
    else:
136
        raise NotImplementedError("We do not random.shuffle lists in-place")
137

138
    perm = torch.randperm(tensor.shape[0])
139
    xp = tensor[perm]
140
    tensor.copy_(xp)
141

142

143
@deco_stream
144
def randint(low, high=None, size=None):
145
    if size is None:
146
        size = ()
147
    if not isinstance(size, (tuple, list)):
148
        size = (size,)
149
    if high is None:
150
        low, high = 0, low
151
    values = torch.randint(low, high, size=size)
152
    return array_or_scalar(values, int, return_scalar=size == ())
153

154

155
@deco_stream
156
@normalizer
157
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
158
    # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
159
    if a.numel() == 1:
160
        a = torch.arange(a)
161

162
    # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises
163

164
    # number of draws
165
    if size is None:
166
        num_el = 1
167
    elif _util.is_sequence(size):
168
        num_el = 1
169
        for el in size:
170
            num_el *= el
171
    else:
172
        num_el = size
173

174
    # prepare the probabilities
175
    if p is None:
176
        p = torch.ones_like(a) / a.shape[0]
177

178
    # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
179
    atol = sqrt(torch.finfo(p.dtype).eps)
180
    if abs(p.sum() - 1.0) > atol:
181
        raise ValueError("probabilities do not sum to 1.")
182

183
    # actually sample
184
    indices = torch.multinomial(p, num_el, replacement=replace)
185

186
    if _util.is_sequence(size):
187
        indices = indices.reshape(size)
188

189
    samples = a[indices]
190

191
    return samples
192

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.