pytorch

Форк
0
/
context.py 
128 строк · 3.9 Кб
1
import contextlib
2

3
import functools
4
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
5

6
import torchgen.local as local
7
from torchgen.model import (
8
    BackendIndex,
9
    DispatchKey,
10
    NativeFunction,
11
    NativeFunctionsGroup,
12
    NativeFunctionsViewGroup,
13
)
14
from torchgen.utils import context, S, T
15

16
# Helper functions for defining generators on things in the model
17

18
F = TypeVar(
19
    "F",
20
    NativeFunction,
21
    NativeFunctionsGroup,
22
    NativeFunctionsViewGroup,
23
    Union[NativeFunction, NativeFunctionsGroup],
24
    Union[NativeFunction, NativeFunctionsViewGroup],
25
)
26

27
F2 = TypeVar(
28
    "F2",
29
    NativeFunction,
30
    NativeFunctionsGroup,
31
    Optional[NativeFunction],
32
    bool,
33
    str,
34
)
35

36
F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
37

38

39
@contextlib.contextmanager
40
def native_function_manager(
41
    g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
42
) -> Iterator[None]:
43
    if isinstance(g, NativeFunctionsGroup):
44
        # By default, we associate all errors with structured native functions
45
        # with the out variant.  In some cases, it might be better to have
46
        # a more specific place to hang things; if so, use
47
        # native_function_manager again on the inside
48
        f = g.out
49
    elif isinstance(g, NativeFunctionsViewGroup):
50
        # We associate errors with the view operator
51
        f = g.view
52
    else:
53
        f = g
54
    with context(lambda: f"in native_functions.yaml line {f.loc}:\n  {f.func}"):
55
        with local.parametrize(
56
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
57
            use_ilistref_for_tensor_lists=f.part_of_structured_group,
58
        ):
59
            yield
60

61

62
# Given a function that operates on NativeFunction, wrap it into a new function
63
# that sets some appropriate context managers for that native function.
64
# YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
65
# (you will get an error if we try to access the local variables without having
66
# set them).
67
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
68
    @functools.wraps(func)
69
    def wrapper(f: F) -> T:
70
        with native_function_manager(f):
71
            return func(f)
72

73
    return wrapper
74

75

76
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
77
    @functools.wraps(func)
78
    def wrapper(f: F, f2: F2) -> T:
79
        # The first native_function is assumed to be the one with the appropriate context.
80
        with native_function_manager(f):
81
            return func(f, f2)
82

83
    return wrapper
84

85

86
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
87
    @functools.wraps(func)
88
    def wrapper(slf: S, f: F) -> T:
89
        with native_function_manager(f):
90
            return func(slf, f)
91

92
    return wrapper
93

94

95
def method_with_nested_native_function(
96
    func: Callable[[S, F3], T]
97
) -> Callable[[S, F3], T]:
98
    @functools.wraps(func)
99
    def wrapper(slf: S, f: F3) -> T:
100
        with native_function_manager(f[0]):
101
            return func(slf, f)
102

103
    return wrapper
104

105

106
# Convenience decorator for functions that explicitly take in a BackendIndex,
107
# instead of indirectly taking one in as a closure
108
def with_native_function_and_index(
109
    func: Callable[[F, BackendIndex], T]
110
) -> Callable[[F, BackendIndex], T]:
111
    @functools.wraps(func)
112
    def wrapper(f: F, backend_index: BackendIndex) -> T:
113
        with native_function_manager(f):
114
            return func(f, backend_index)
115

116
    return wrapper
117

118

119
# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
120
def with_native_function_and_indices(
121
    func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
122
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
123
    @functools.wraps(func)
124
    def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
125
        with native_function_manager(f):
126
            return func(f, backend_indices)
127

128
    return wrapper
129

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.