1""" Model / Layer Config singleton state
2"""
3import os4import warnings5from typing import Any, Optional6
7import torch8
9__all__ = [10'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',11'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'12]
13
# Set to True to prefer layers with no jit optimization (includes activations)
_NO_JIT = False

# Set to True to prefer activation layers with no jit optimization
# NOTE: not currently used, as the only layers obeying the jit flags so far are activations,
# so there is no difference between no_jit and no_activation_jit yet. This will change as
# more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with 'SAME' padding via ONNX
_EXPORTABLE = False

# Set to True when intending to use torch.jit.script on a model
_SCRIPTABLE = False

# Use torch.nn.functional.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
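
# NOTE: TIMM_FUSED_ATTN is read only once, here at import time; changing the environment
# variable afterwards has no effect. Use set_fused_attn() below to change the mode at
# runtime instead.
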
def is_no_jit():
    return _NO_JIT


class set_no_jit:
    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False
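
# Usage sketch (illustrative only, not part of the library API). Note that the flag is
# flipped in __init__ rather than __enter__, so constructing set_no_jit(True) takes
# effect immediately; the context-manager form simply restores the prior value on exit.
def _example_set_no_jit():
    with set_no_jit(True):
        assert is_no_jit()       # layers built here see the no-jit preference
    assert not is_no_jit()       # prior value (False by default) restored on exit
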
def is_exportable():
    return _EXPORTABLE


class set_exportable:
    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False
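
# Usage sketch (illustrative only): ONNX export is the main reason to set this flag; it
# also disables the fused attention path in use_fused_attn() below. `create_model_fn` and
# `example_input` are hypothetical placeholders, not library names.
def _example_export_onnx(create_model_fn, example_input):
    # Layers generally consult these flags when they are constructed, so the model
    # should be built while the flag is active, not merely exported under it.
    with set_exportable(True):
        model = create_model_fn()
    torch.onnx.export(model, example_input, 'model.onnx')
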
def is_scriptable():
    return _SCRIPTABLE


class set_scriptable:
    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False
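
# Usage sketch (illustrative only): build the model while the scriptable flag is set so
# that jit-compatible layer variants are selected, then compile with TorchScript.
# `create_model_fn` is a hypothetical placeholder.
def _example_script_model(create_model_fn):
    with set_scriptable(True):
        model = create_model_fn()
    return torch.jit.script(model)
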
class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False
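
# Usage sketch (illustrative only): set several flags for the duration of model
# creation in one statement; flags left as None keep their current values. This is the
# kind of form higher-level factory code can use to apply per-call layer config.
# `create_model_fn` is a hypothetical placeholder.
def _example_layer_config(create_model_fn):
    with set_layer_config(scriptable=True, no_jit=True):
        model = create_model_fn()
    return model
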
def use_fused_attn(experimental: bool = False) -> bool:
    # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
    if not _HAS_FUSED_ATTN or _EXPORTABLE:
        return False
    if experimental:
        return _USE_FUSED_ATTN > 1
    return _USE_FUSED_ATTN > 0
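
# Consumer sketch (illustrative only): attention layers typically capture the flag once
# at construction and branch on it in forward(). `_ExampleAttention` is a hypothetical,
# pared-down module, not a timm layer; it assumes scale == head_dim ** -0.5 so the two
# paths match the fused default scaling.
class _ExampleAttention(torch.nn.Module):
    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale
        self.fused_attn = use_fused_attn()  # resolved once, at layer build time

    def forward(self, q, k, v):
        if self.fused_attn:
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        return attn.softmax(dim=-1) @ v
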
def set_fused_attn(enable: bool = True, experimental: bool = False):
    global _USE_FUSED_ATTN
    if not _HAS_FUSED_ATTN:
        warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
        return
    if experimental and enable:
        _USE_FUSED_ATTN = 2
    elif enable:
        _USE_FUSED_ATTN = 1
    else:
        _USE_FUSED_ATTN = 0
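
# Usage sketch (illustrative only): toggle fused attention globally before building
# models, e.g. to compare numerics of the fused vs. unfused paths.
def _example_toggle_fused_attn():
    # Assumes a PyTorch build that has F.scaled_dot_product_attention and that the
    # exportable flag is not set (either condition makes use_fused_attn() return False).
    set_fused_attn(False)
    assert not use_fused_attn()
    set_fused_attn(True, experimental=True)   # opt in to experimental uses as well
    assert use_fused_attn(experimental=True)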