""" Model / Layer Config singleton state
"""
import os
import warnings
from typing import Any, Optional

import torch

__all__ = [
    'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
]

# Set to True if you prefer layers with no jit optimization (includes activations)
_NO_JIT = False

# Set to True if you prefer activation layers with no jit optimization
# NOTE not currently used, as there is no difference between no_jit and no_activation_jit while the
# only layers obeying the jit flags are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if you want to use torch.jit.script on a model
_SCRIPTABLE = False


# Use torch.nn.functional.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
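# Example (hypothetical shell invocation): the override is read once at import time,
# so it is set in the environment before the process starts, e.g.
#   TIMM_FUSED_ATTN=0 python your_script.py    # force fused attention off
#   TIMM_FUSED_ATTN=2 python your_script.py    # also enable experimental fused paths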


def is_no_jit():
    return _NO_JIT


class set_no_jit:
    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False
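

# Usage sketch (hypothetical `build_model` helper): because the flag is assigned in
# __init__, the set_* classes in this module work both as plain setters and as context
# managers that restore the previous value on exit:
#   with set_no_jit(True):        # temporary: restored when the block exits
#       model = build_model()     # layers built here see is_no_jit() == True
#   set_no_jit(True)              # plain call: persistent toggle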


def is_exportable():
    return _EXPORTABLE


class set_exportable:
    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False


def is_scriptable():
    return _SCRIPTABLE


class set_scriptable:
    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False


class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False
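

# Usage sketch (assumes timm's create_model factory): the context manager is entered
# around model construction so layers pick up the flags as they are built, e.g.
#   with set_layer_config(exportable=True):
#       model = create_model('resnet50')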


def use_fused_attn(experimental: bool = False) -> bool:
    # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
    if not _HAS_FUSED_ATTN or _EXPORTABLE:
        return False
    if experimental:
        return _USE_FUSED_ATTN > 1
    return _USE_FUSED_ATTN > 0
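

# Consumer sketch (hypothetical, simplified Attention module): callers typically check
# the flag once at construction time and branch in forward():
#
#   import torch.nn.functional as F
#
#   class Attention(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fused_attn = use_fused_attn()
#
#       def forward(self, q, k, v):
#           if self.fused_attn:
#               return F.scaled_dot_product_attention(q, k, v)
#           attn = (q @ k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
#           return attn.softmax(dim=-1) @ v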


def set_fused_attn(enable: bool = True, experimental: bool = False):
    global _USE_FUSED_ATTN
    if not _HAS_FUSED_ATTN:
        warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
        return
    if experimental and enable:
        _USE_FUSED_ATTN = 2
    elif enable:
        _USE_FUSED_ATTN = 1
    else:
        _USE_FUSED_ATTN = 0
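

# Example calls (illustrative):
#   set_fused_attn(False)                       # disable fused attention globally
#   set_fused_attn(True, experimental=True)     # enable tested and experimental paths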