3
"""Wrapper to mimic (parts of) np.random API surface.
5
NumPy has strict guarantees on reproducibility etc; here we don't give any.
7
Q: default dtype is float64 in numpy
10
from __future__ import annotations
14
from typing import Optional
18
from . import _dtypes_impl, _util
19
from ._normalizations import array_or_scalar, ArrayLike, normalizer
37
def use_numpy_random():
39
import torch._dynamo.config as config
41
return config.use_numpy_random_stream
45
@functools.wraps(func)
46
def inner(*args, **kwds):
47
if not use_numpy_random():
48
return func(*args, **kwds)
52
from ._ndarray import ndarray
54
f = getattr(numpy.random, func.__name__)
58
arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args
61
key: val.tensor.numpy() if isinstance(val, ndarray) else val
62
for key, val in kwds.items()
65
value = f(*args, **kwds)
68
if isinstance(value, numpy.ndarray):
69
value = ndarray(torch.as_tensor(value))
79
torch.random.manual_seed(seed)
83
def random_sample(size=None):
86
dtype = _dtypes_impl.default_dtypes().float_dtype
87
values = torch.empty(size, dtype=dtype).uniform_()
88
return array_or_scalar(values, return_scalar=size == ())
94
return random_sample(size)
102
def uniform(low=0.0, high=1.0, size=None):
105
dtype = _dtypes_impl.default_dtypes().float_dtype
106
values = torch.empty(size, dtype=dtype).uniform_(low, high)
107
return array_or_scalar(values, return_scalar=size == ())
112
dtype = _dtypes_impl.default_dtypes().float_dtype
113
values = torch.randn(size, dtype=dtype)
114
return array_or_scalar(values, return_scalar=size == ())
118
def normal(loc=0.0, scale=1.0, size=None):
121
dtype = _dtypes_impl.default_dtypes().float_dtype
122
values = torch.empty(size, dtype=dtype).normal_(loc, scale)
123
return array_or_scalar(values, return_scalar=size == ())
129
from ._ndarray import ndarray
131
if isinstance(x, torch.Tensor):
133
elif isinstance(x, ndarray):
136
raise NotImplementedError("We do not random.shuffle lists in-place")
138
perm = torch.randperm(tensor.shape[0])
144
def randint(low, high=None, size=None):
147
if not isinstance(size, (tuple, list)):
151
values = torch.randint(low, high, size=size)
152
return array_or_scalar(values, int, return_scalar=size == ())
157
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
167
elif _util.is_sequence(size):
176
p = torch.ones_like(a) / a.shape[0]
179
atol = sqrt(torch.finfo(p.dtype).eps)
180
if abs(p.sum() - 1.0) > atol:
181
raise ValueError("probabilities do not sum to 1.")
184
indices = torch.multinomial(p, num_el, replacement=replace)
186
if _util.is_sequence(size):
187
indices = indices.reshape(size)