pytorch

Форк
0
/
__init__.py 
140 строк · 4.0 Кб
1
from typing import Any, Dict, List, Optional
2

3
import torch.fx
4
import torch.utils._pytree as pytree
5

6
__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]
7

8

9
def compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
):
    """
    Compile an FX graph with TorchInductor.

    Useful for compiling FX graphs that were captured by some means other
    than TorchDynamo.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options.  See `torch._inductor.config`.

    Returns:
        A callable behaving like ``gm`` but compiled for speed.
    """
    # Deferred import: keeps `import torch._inductor` cheap by not pulling
    # in the whole compilation stack at module-import time.
    from .compile_fx import compile_fx as _compile_fx

    return _compile_fx(gm, example_inputs, config_patches=options)
29

30

31
def aot_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Ahead-of-time compile an FX graph with TorchInductor into a shared library.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options.  See `torch._inductor.config`.

    Returns:
        Path to the generated shared library.
    """
    # Deferred import to avoid paying for the compile stack on module import.
    from .compile_fx import compile_fx_aot

    # The input/output pytree specs are serialized into the .so as constant
    # strings so the AOT runtime can reconstruct the calling convention.
    serialized_in_spec = ""
    serialized_out_spec = ""

    codegen = gm.graph._codegen
    if isinstance(codegen, torch.fx.graph._PyTreeCodeGen):
        # Strip the pytree codegen from the graph before compiling; the spec
        # information is carried separately via the serialized strings below.
        gm.graph._codegen = torch.fx.graph.CodeGen()
        gm.recompile()

        in_spec = codegen.pytree_info.in_spec
        out_spec = codegen.pytree_info.out_spec
        if in_spec is not None:
            serialized_in_spec = pytree.treespec_dumps(in_spec)
        if out_spec is not None:
            serialized_out_spec = pytree.treespec_dumps(out_spec)

    # Build a fresh patches dict (never mutate the caller's `options`).
    patches: Dict[str, Any] = {} if options is None else dict(options)
    patches["aot_inductor.serialized_in_spec"] = serialized_in_spec
    patches["aot_inductor.serialized_out_spec"] = serialized_out_spec

    return compile_fx_aot(
        gm,
        example_inputs,
        config_patches=patches,
    )
81

82

83
def list_mode_options(
    mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> Dict[str, Any]:
    r"""Returns a dictionary describing the optimizations that each of the available
    modes passed to `torch.compile()` performs.

    Args:
        mode (str, optional): The mode to return the optimizations for.
        If None, returns optimizations for all modes
        dynamic (bool, optional): Whether dynamic shape is enabled.

    Example::
        >>> torch._inductor.list_mode_options()
    """
    all_modes: Dict[str, Dict[str, bool]] = {
        # no extra optimizations beyond the defaults
        "default": {},
        # record/replay via CUDA graphs to reduce kernel-launch overhead
        "reduce-overhead": {"triton.cudagraphs": True},
        # exhaustive kernel autotuning, without CUDA graphs
        "max-autotune-no-cudagraphs": {"max_autotune": True},
        # exhaustive kernel autotuning plus CUDA graphs
        "max-autotune": {"max_autotune": True, "triton.cudagraphs": True},
    }
    if mode:
        # Raises KeyError for an unrecognized mode, same as the mapping lookup.
        return all_modes[mode]
    return all_modes  # type: ignore[return-value]
116

117

118
def list_options() -> List[str]:
119
    r"""Returns a dictionary describing the optimizations and debug configurations
120
    that are available to `torch.compile()`.
121

122
    The options are documented in `torch._inductor.config`.
123

124
    Example::
125

126
        >>> torch._inductor.list_options()
127
    """
128

129
    from torch._inductor import config
130

131
    current_config: Dict[str, Any] = config.shallow_copy_dict()
132

133
    return list(current_config.keys())
134

135

136
def cudagraph_mark_step_begin():
    """Indicates that a new iteration of inference or training is about to begin."""
    # Deferred import: cudagraph_trees is only needed when CUDA graphs are in use.
    from .cudagraph_trees import mark_step_begin as _mark_step_begin

    _mark_step_begin()
141

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.