pytorch

Форк
0
32 строки · 837.0 Байт
1
# mypy: allow-untyped-defs
2
from contextlib import contextmanager
3

4
import torch
5
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
6

7

8
__all__ = ["is_available", "flags", "set_flags"]
9

10

11
def is_available():
12
    r"""Return whether PyTorch is built with NNPACK support."""
13
    return torch._nnpack_available()
14

15

16
def set_flags(_enabled):
17
    r"""Set if nnpack is enabled globally"""
18
    orig_flags = (torch._C._get_nnpack_enabled(),)
19
    torch._C._set_nnpack_enabled(_enabled)
20
    return orig_flags
21

22

23
@contextmanager
24
def flags(enabled=False):
25
    r"""Context manager for setting if nnpack is enabled globally"""
26
    with __allow_nonbracketed_mutation():
27
        orig_flags = set_flags(enabled)
28
    try:
29
        yield
30
    finally:
31
        with __allow_nonbracketed_mutation():
32
            set_flags(orig_flags[0])
33

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.