pytorch
1# mypy: allow-untyped-defs
2from contextlib import contextmanager3
4import torch5from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule6
7
# Public API of this module: the availability check plus the global flag helpers defined below.
__all__ = ["is_available", "flags", "set_flags"]
10
def is_available():
    r"""Report whether this PyTorch build was compiled with NNPACK support.

    Returns:
        Whatever ``torch._nnpack_available()`` reports for the running build.
    """
    nnpack_built_in = torch._nnpack_available()
    return nnpack_built_in
15
16def set_flags(_enabled):17r"""Set if nnpack is enabled globally"""18orig_flags = (torch._C._get_nnpack_enabled(),)19torch._C._set_nnpack_enabled(_enabled)20return orig_flags21
22
@contextmanager
def flags(enabled=False):
    r"""Temporarily set whether NNPACK is enabled globally.

    The setting in effect on entry is captured and restored on exit, including
    when the body of the ``with`` statement raises.

    Args:
        enabled: value passed to ``set_flags`` for the duration of the
            context. Defaults to ``False``.
    """
    # Both the change and the restore run inside __allow_nonbracketed_mutation();
    # presumably this lifts a torch.backends guard against flag mutation
    # outside a flags() block -- TODO confirm against torch.backends.
    with __allow_nonbracketed_mutation():
        saved = set_flags(enabled)
    try:
        yield
    finally:
        # Put back whatever was active before entering the context.
        with __allow_nonbracketed_mutation():
            set_flags(saved[0])