pytorch

Форк
0
/
context.py 
130 строк · 3.9 Кб
1
from __future__ import annotations
2

3
import contextlib
4
import functools
5
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
6

7
import torchgen.local as local
8
from torchgen.model import (
9
    BackendIndex,
10
    DispatchKey,
11
    NativeFunction,
12
    NativeFunctionsGroup,
13
    NativeFunctionsViewGroup,
14
)
15
from torchgen.utils import context, S, T
16

17

18
# Helper functions for defining generators on things in the model
19

20
F = TypeVar(
21
    "F",
22
    NativeFunction,
23
    NativeFunctionsGroup,
24
    NativeFunctionsViewGroup,
25
    Union[NativeFunction, NativeFunctionsGroup],
26
    Union[NativeFunction, NativeFunctionsViewGroup],
27
)
28

29
F2 = TypeVar(
30
    "F2",
31
    NativeFunction,
32
    NativeFunctionsGroup,
33
    Optional[NativeFunction],
34
    bool,
35
    str,
36
)
37

38
F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
39

40

41
@contextlib.contextmanager
42
def native_function_manager(
43
    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
44
) -> Iterator[None]:
45
    if isinstance(g, NativeFunctionsGroup):
46
        # By default, we associate all errors with structured native functions
47
        # with the out variant.  In some cases, it might be better to have
48
        # a more specific place to hang things; if so, use
49
        # native_function_manager again on the inside
50
        f = g.out
51
    elif isinstance(g, NativeFunctionsViewGroup):
52
        # We associate errors with the view operator
53
        f = g.view
54
    else:
55
        f = g
56
    with context(lambda: f"in native_functions.yaml line {f.loc}:\n  {f.func}"):
57
        with local.parametrize(
58
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
59
            use_ilistref_for_tensor_lists=f.part_of_structured_group,
60
        ):
61
            yield
62

63

64
# Given a function that operates on NativeFunction, wrap it into a new function
65
# that sets some appropriate context managers for that native function.
66
# YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
67
# (you will get an error if we try to access the local variables without having
68
# set them).
69
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
70
    @functools.wraps(func)
71
    def wrapper(f: F) -> T:
72
        with native_function_manager(f):
73
            return func(f)
74

75
    return wrapper
76

77

78
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
79
    @functools.wraps(func)
80
    def wrapper(f: F, f2: F2) -> T:
81
        # The first native_function is assumed to be the one with the appropriate context.
82
        with native_function_manager(f):
83
            return func(f, f2)
84

85
    return wrapper
86

87

88
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
89
    @functools.wraps(func)
90
    def wrapper(slf: S, f: F) -> T:
91
        with native_function_manager(f):
92
            return func(slf, f)
93

94
    return wrapper
95

96

97
def method_with_nested_native_function(
98
    func: Callable[[S, F3], T]
99
) -> Callable[[S, F3], T]:
100
    @functools.wraps(func)
101
    def wrapper(slf: S, f: F3) -> T:
102
        with native_function_manager(f[0]):
103
            return func(slf, f)
104

105
    return wrapper
106

107

108
# Convenience decorator for functions that explicitly take in a BackendIndex,
109
# instead of indirectly taking one in as a closure
110
def with_native_function_and_index(
111
    func: Callable[[F, BackendIndex], T]
112
) -> Callable[[F, BackendIndex], T]:
113
    @functools.wraps(func)
114
    def wrapper(f: F, backend_index: BackendIndex) -> T:
115
        with native_function_manager(f):
116
            return func(f, backend_index)
117

118
    return wrapper
119

120

121
# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
122
def with_native_function_and_indices(
123
    func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
124
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
125
    @functools.wraps(func)
126
    def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
127
        with native_function_manager(f):
128
            return func(f, backend_indices)
129

130
    return wrapper
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.