pytorch

Форк
0
40 строк · 976.0 Байт
1
# mypy: allow-untyped-defs
2
import contextlib
3
from typing import Tuple, Union
4

5
import torch
6
from torch._C._functorch import (
7
    get_single_level_autograd_function_allowed,
8
    set_single_level_autograd_function_allowed,
9
    unwrap_if_dead,
10
)
11
from torch.utils._exposed_in import exposed_in
12

13

14
__all__ = [
15
    "exposed_in",
16
    "argnums_t",
17
    "enable_single_level_autograd_function",
18
    "unwrap_dead_wrappers",
19
]
20

21

22
@contextlib.contextmanager
23
def enable_single_level_autograd_function():
24
    try:
25
        prev_state = get_single_level_autograd_function_allowed()
26
        set_single_level_autograd_function_allowed(True)
27
        yield
28
    finally:
29
        set_single_level_autograd_function_allowed(prev_state)
30

31

32
def unwrap_dead_wrappers(args):
33
    # NB: doesn't use tree_map_only for performance reasons
34
    result = tuple(
35
        unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
36
    )
37
    return result
38

39

40
argnums_t = Union[int, Tuple[int, ...]]
41

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.