pytorch
40 строк · 976.0 Байт
1# mypy: allow-untyped-defs
2import contextlib3from typing import Tuple, Union4
5import torch6from torch._C._functorch import (7get_single_level_autograd_function_allowed,8set_single_level_autograd_function_allowed,9unwrap_if_dead,10)
11from torch.utils._exposed_in import exposed_in12
13
# Public names exported by ``from ... import *``. ``exposed_in`` is
# re-exported from ``torch.utils._exposed_in``; the other three are defined
# in this module.
__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]
20
21
@contextlib.contextmanager
def enable_single_level_autograd_function():
    """Temporarily set the functorch single-level-autograd-function flag to True.

    On entry, the current value reported by
    ``get_single_level_autograd_function_allowed()`` is saved and the flag is
    set to ``True`` via ``set_single_level_autograd_function_allowed``. On
    exit, whether normal or via an exception, the saved value is restored.
    """
    # Read the previous state *before* entering ``try``: if this call were
    # inside the ``try`` and raised, the ``finally`` clause would reference an
    # unbound ``prev_state`` and replace the real error with UnboundLocalError.
    prev_state = get_single_level_autograd_function_allowed()
    try:
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        set_single_level_autograd_function_allowed(prev_state)
31
def unwrap_dead_wrappers(args):
    """Return a tuple with ``unwrap_if_dead`` applied to each tensor in ``args``.

    Elements of ``args`` that are ``torch.Tensor`` instances are passed through
    ``unwrap_if_dead``; every other element is kept unchanged. Order is
    preserved and the result is always a ``tuple``.
    """
    # NB: a flat pass over ``args`` instead of tree_map_only, for performance.
    unwrapped = [
        unwrap_if_dead(item) if isinstance(item, torch.Tensor) else item
        for item in args
    ]
    return tuple(unwrapped)
39
# Type alias for an ``argnums`` value: either a single ``int`` or a
# variable-length tuple of ``int``s.
argnums_t = Union[int, Tuple[int, ...]]