pytorch

Форк
0
/
_funcs.py 
76 строк · 2.0 Кб
1
# mypy: ignore-errors
2

3
import inspect
4
import itertools
5

6
from . import _funcs_impl, _reductions_impl
7
from ._normalizations import normalizer
8

9

10
# _funcs_impl.py contains functions which mimic NumPy's eponymous equivalents,
11
# and consume/return PyTorch tensors/dtypes.
12
# They are also type annotated.
13
# Pull these functions from _funcs_impl and decorate them with @normalizer, which
14
# - Converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists, Python scalars, etc into a `torch.Tensor`.
15
# - Maps NumPy dtypes to PyTorch dtypes
16
# - If the input to the `axis` kwarg is an ndarray, it maps it into a tuple
17
# - Implements the semantics for the `out=` arg
18
# - Wraps back the outputs into `torch._numpy.ndarrays`
19

20

21
def _public_functions(mod):
22
    def is_public_function(f):
23
        return inspect.isfunction(f) and not f.__name__.startswith("_")
24

25
    return inspect.getmembers(mod, is_public_function)
26

27

28
# We fill in __all__ in the loop below
29
__all__ = []
30

31
# decorate implementer functions with argument normalizers and export to the top namespace
32
for name, func in itertools.chain(
33
    _public_functions(_funcs_impl), _public_functions(_reductions_impl)
34
):
35
    if name in ["percentile", "quantile", "median"]:
36
        decorated = normalizer(func, promote_scalar_result=True)
37
    elif name == "einsum":
38
        # normalized manually
39
        decorated = func
40
    else:
41
        decorated = normalizer(func)
42

43
    decorated.__qualname__ = name
44
    decorated.__name__ = name
45
    vars()[name] = decorated
46
    __all__.append(name)
47

48

49
"""
50
Vendored objects from numpy.lib.index_tricks
51
"""
52

53

54
class IndexExpression:
55
    """
56
    Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
57
    last revision: 1999-7-23
58

59
    Cosmetic changes by T. Oliphant 2001
60
    """
61

62
    def __init__(self, maketuple):
63
        self.maketuple = maketuple
64

65
    def __getitem__(self, item):
66
        if self.maketuple and not isinstance(item, tuple):
67
            return (item,)
68
        else:
69
            return item
70

71

72
index_exp = IndexExpression(maketuple=True)
73
s_ = IndexExpression(maketuple=False)
74

75

76
__all__ += ["index_exp", "s_"]
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.