pytorch

Форк
0
/
__init__.py 
95 строк · 3.6 Кб
1
""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
2
import contextlib
3
from typing import List, Union
4
from warnings import warn
5

6
from torch.backends.cuda import (
7
    can_use_efficient_attention,
8
    can_use_flash_attention,
9
    enable_flash_sdp,
10
    enable_math_sdp,
11
    enable_mem_efficient_sdp,
12
    flash_sdp_enabled,
13
    math_sdp_enabled,
14
    mem_efficient_sdp_enabled,
15
    SDPAParams,
16
)
17

18
__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
19

20
# Note: [SDPA warnings]
21
# TODO: Consider using this for sdpa regardless of subclasses
22
# This only effects users of bias subclasses
23
# If this is set to True, we will warn the user if they are not using the fused kernels
24
# As well, it will raise warnings for all the reasons why the fused kernels can't be run.
25
# To set this to True, run
26
# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
27
WARN_FOR_UNFUSED_KERNELS = False
28

29

30
from torch._C import _SDPBackend as SDPBackend
31

32
# Hacks for Sphinx documentation:
33
# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
34
SDPBackend = SDPBackend
35
r"""An enum-like class that contains the different backends for scaled dot product attention.
36
    This backend class is designed to be used with the sdpa_kernel context manager.
37
    See :func: torch.nn.attention.sdpa_kernel for more details.
38

39
    ... warning:: This class is in beta and subject to change.
40
"""
41
SDPBackend.__module__ = __name__
42
SDPBackend.__name__ = "SDPBackend"
43

44

45
def _raise_kernel_warnings(params: SDPAParams) -> None:
    """Emit diagnostic warnings when the fused SDPA kernels cannot run.

    No-op unless the module-level ``WARN_FOR_UNFUSED_KERNELS`` flag is True.
    For each fused backend that rejects ``params``, a summary warning is
    issued and the capability check is re-run with its second argument set to
    True, which makes it report the specific reasons for the rejection.
    """
    if not WARN_FOR_UNFUSED_KERNELS:
        return
    checks = (
        (can_use_efficient_attention, "Efficient attention can't be used because:"),
        (can_use_flash_attention, "Flash attention can't be used because:"),
    )
    for can_use, message in checks:
        if not can_use(params):
            warn(message)
            # Re-run in verbose mode so the backend explains the failure.
            can_use(params, True)
57

58

59
@contextlib.contextmanager
60
def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
61
    r"""
62
    Context manager to select which backend to use for scaled dot product attention.
63

64
    .. warning:: This function is beta and subject to change.
65

66
    Args:
67
        backend (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.
68

69
    This context manager can be used to select which backend to use for scaled dot product attention.
70
    Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends.
71
    """
72
    assert isinstance(
73
        backends, (list, SDPBackend)
74
    ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
75

76
    if isinstance(backends, SDPBackend):
77
        backends = [backends]
78

79
    backends = set(backends)
80
    previous_flash: bool = flash_sdp_enabled()
81
    previous_mem_efficient: bool = mem_efficient_sdp_enabled()
82
    previous_math: bool = math_sdp_enabled()
83
    try:
84
        enable_flash = SDPBackend.FLASH_ATTENTION in backends
85
        enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
86
        enable_math = SDPBackend.MATH in backends
87

88
        enable_flash_sdp(enable_flash)
89
        enable_mem_efficient_sdp(enable_mem_efficient)
90
        enable_math_sdp(enable_math)
91
        yield {}
92
    finally:
93
        enable_flash_sdp(previous_flash)
94
        enable_mem_efficient_sdp(previous_mem_efficient)
95
        enable_math_sdp(previous_math)
96

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.