4
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
6
import torchgen.local as local
7
from torchgen.model import (
12
NativeFunctionsViewGroup,
14
from torchgen.utils import context, S, T
22
NativeFunctionsViewGroup,
23
Union[NativeFunction, NativeFunctionsGroup],
24
Union[NativeFunction, NativeFunctionsViewGroup],
31
Optional[NativeFunction],
36
# Constrained TypeVar for "nested" inputs: a value is either a
# Tuple[NativeFunction, Any] or a List[NativeFunction].
# method_with_nested_native_function (later in this listing) indexes such a
# value with [0] to obtain the object it hands to native_function_manager.
F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
39
# Context manager (via contextlib.contextmanager) that accepts a single
# NativeFunction or one of the two group types, then nests a `context(...)`
# block and a `local.parametrize(...)` block, both driven by a NativeFunction
# named `f`.
# NOTE(review): this listing is line-sampled (the bare integers are line
# numbers from the original file), so the end of the signature, the bodies of
# the isinstance branches that bind `f`, and the remainder of the
# parametrize call are not visible here.
@contextlib.contextmanager
def native_function_manager(
41
g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
43
# Dispatch on the concrete type of `g`. Presumably each branch selects a
# representative NativeFunction as `f` -- TODO confirm against the full file.
if isinstance(g, NativeFunctionsGroup):
49
elif isinstance(g, NativeFunctionsViewGroup):
54
# The message (f.loc and f.func) is passed as a zero-argument lambda rather
# than a pre-built string; presumably `context` attaches it to errors raised
# inside the block -- verify in torchgen.utils.
with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"):
55
# Both keyword values are read directly from attributes of `f`.
with local.parametrize(
56
use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
57
use_ilistref_for_tensor_lists=f.part_of_structured_group,
67
# Decorator for a one-argument callable `func(f)`: the inner `wrapper` opens
# native_function_manager(f) before doing its work.
# NOTE(review): the statement inside the `with` and the decorator's own return
# are elided from this listing -- presumably `return func(f)` and
# `return wrapper`; confirm against the full file.
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
68
# functools.wraps copies func's metadata (e.g. __name__, __doc__) to wrapper.
@functools.wraps(func)
69
def wrapper(f: F) -> T:
70
with native_function_manager(f):
76
# Decorator for a two-argument callable `func(f, f2)`: the inner `wrapper`
# opens native_function_manager for the first argument only; `f2` is not
# passed to the manager in the visible code.
# NOTE(review): the statement inside the `with` and the decorator's own return
# are elided from this listing -- presumably `return func(f, f2)` and
# `return wrapper`; confirm against the full file.
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
77
# functools.wraps copies func's metadata (e.g. __name__, __doc__) to wrapper.
@functools.wraps(func)
78
def wrapper(f: F, f2: F2) -> T:
80
with native_function_manager(f):
86
# Decorator for a callable shaped like a method, `func(slf, f)`: the inner
# `wrapper` receives the extra leading argument `slf` and opens
# native_function_manager(f) for the second argument.
# NOTE(review): the statement inside the `with` and the decorator's own return
# are elided from this listing -- presumably `return func(slf, f)` and
# `return wrapper`; confirm against the full file.
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
87
# functools.wraps copies func's metadata (e.g. __name__, __doc__) to wrapper.
@functools.wraps(func)
88
def wrapper(slf: S, f: F) -> T:
89
with native_function_manager(f):
95
# Decorator for a callable shaped like a method, `func(slf, f)`, where `f` is
# an F3 value (an indexable container); the inner `wrapper` opens
# native_function_manager on its first element, f[0].
# NOTE(review): the statement inside the `with` and the decorator's own return
# are elided from this listing -- presumably `return func(slf, f)` and
# `return wrapper`; confirm against the full file.
def method_with_nested_native_function(
96
func: Callable[[S, F3], T]
97
) -> Callable[[S, F3], T]:
98
# functools.wraps copies func's metadata (e.g. __name__, __doc__) to wrapper.
@functools.wraps(func)
99
def wrapper(slf: S, f: F3) -> T:
100
# Only element [0] of the container is handed to the manager.
with native_function_manager(f[0]):
108
# Decorator for a callable `func(f, backend_index)`: the inner `wrapper`
# opens native_function_manager(f) and, inside it, returns
# func(f, backend_index) unchanged.
# NOTE(review): the decorator's own return is elided from this listing --
# presumably `return wrapper`; confirm against the full file.
def with_native_function_and_index(
109
func: Callable[[F, BackendIndex], T]
110
) -> Callable[[F, BackendIndex], T]:
111
# functools.wraps copies func's metadata (e.g. __name__, __doc__) to wrapper.
@functools.wraps(func)
112
def wrapper(f: F, backend_index: BackendIndex) -> T:
113
with native_function_manager(f):
114
# The call happens while the manager is still active.
return func(f, backend_index)
120
# Decorator for a callable `func(f, backend_indices)`, where backend_indices
# is annotated as a Dict[DispatchKey, BackendIndex]: the inner `wrapper` opens
# native_function_manager(f) and, inside it, returns func(f, backend_indices)
# unchanged.
# NOTE(review): the decorator's own return is elided from this listing --
# presumably `return wrapper`; confirm against the full file.
def with_native_function_and_indices(
121
func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
122
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
123
# functools.wraps copies func's metadata (e.g. __name__, __doc__) to wrapper.
@functools.wraps(func)
124
def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
125
with native_function_manager(f):
126
# The call happens while the manager is still active.
return func(f, backend_indices)