pytorch
1# mypy: ignore-errors
2
3import inspect4import itertools5
6from . import _funcs_impl, _reductions_impl7from ._normalizations import normalizer8
9
# _funcs_impl.py contains functions which mimic their eponymous NumPy
# equivalents and consume/return PyTorch tensors/dtypes.
# They are also type annotated.
# Pull the public functions from _funcs_impl (and _reductions_impl) and wrap
# them with `normalizer`, which:
# - Converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists,
#   Python scalar, etc., into a `torch.Tensor`.
# - Maps NumPy dtypes to PyTorch dtypes.
# - Maps an ndarray passed to the `axis` kwarg into a tuple.
# - Implements the semantics of the `out=` argument.
# - Wraps the outputs back into `torch._numpy.ndarray`s.
19
20
21def _public_functions(mod):22def is_public_function(f):23return inspect.isfunction(f) and not f.__name__.startswith("_")24
25return inspect.getmembers(mod, is_public_function)26
27
# We fill in __all__ in the loop below
__all__ = []

# Decorate implementer functions with argument normalizers and export them to
# this module's top-level namespace.
for name, func in itertools.chain(
    _public_functions(_funcs_impl), _public_functions(_reductions_impl)
):
    if name in {"percentile", "quantile", "median"}:
        # These three are the only ones wrapped with
        # promote_scalar_result=True; see `normalizer` in ._normalizations
        # for the exact semantics of that flag.
        decorated = normalizer(func, promote_scalar_result=True)
    elif name == "einsum":
        # Exported unwrapped: einsum is normalized manually (presumably
        # inside its own implementation -- confirm in _funcs_impl).
        decorated = func
    else:
        decorated = normalizer(func)

    # Make the exported object's name/qualname match its public name.
    decorated.__qualname__ = name
    decorated.__name__ = name
    # Explicit globals() rather than a bare vars(): at module scope both
    # return this module's namespace, but globals() states the intent.
    globals()[name] = decorated
    __all__.append(name)
48
49"""
50Vendored objects from numpy.lib.index_tricks
51"""
52
53
class IndexExpression:
    """Build index objects with subscript syntax.

    ``obj[...]`` hands back whatever was written inside the brackets. When
    ``maketuple`` is truthy and that item is not already a tuple, it is
    wrapped in a 1-tuple, so the result is always a tuple.

    Vendored from ``numpy.lib.index_tricks``.

    Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
    last revision: 1999-7-23

    Cosmetic changes by T. Oliphant 2001
    """

    def __init__(self, maketuple):
        # Stored as given; only its truthiness is consulted in __getitem__.
        self.maketuple = maketuple

    def __getitem__(self, item):
        needs_wrapping = self.maketuple and not isinstance(item, tuple)
        return (item,) if needs_wrapping else item
71
# Ready-made instances: `index_exp[...]` always yields a tuple, while
# `s_[...]` returns the subscript unchanged.
index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)

# Export both helpers alongside the decorated functions.
__all__.extend(("index_exp", "s_"))