pytorch

Форк
0
190 строк · 5.9 Кб
1
# mypy: allow-untyped-defs
"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler directly rather than
consulting the thread-local storage (TLS).  It does not use most of the
methods on the full handler, only those that have corresponding Sympy
expressions.  To see an example of a full handler, see
torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""

11
import functools
12
import logging
13
from typing import Any, Dict, Union
14

15
import sympy
16
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
17

18
import torch
19

20
from .functions import (
21
    CeilToInt,
22
    CleanDiv,
23
    FloatPow,
24
    FloatTrueDiv,
25
    FloorDiv,
26
    FloorToInt,
27
    Identity,
28
    IntTrueDiv,
29
    IsNonOverlappingAndDenseIndicator,
30
    Max,
31
    Min,
32
    Mod,
33
    ModularIndexing,
34
    PowByNatural,
35
    PythonMod,
36
    RoundDecimal,
37
    RoundToInt,
38
    ToFloat,
39
    TruncToFloat,
40
    TruncToInt,
41
    Where,
42
)
43

44

45
# Logger named after this module; used for debug tracing of handler calls
# and for reporting handler failures.
log = logging.getLogger(name=__name__)
46

47

48
# TODO: Dedupe this with SYMPY_INTERP


@functools.lru_cache(None)
def handlers() -> Dict[Any, str]:
    """Return the table mapping Sympy expression classes to handler names.

    ``_run_sympy_handler`` looks up ``expr.func`` in this table and calls the
    method with the returned name on the analysis object.  The table is
    built on the first call and then cached by ``functools.lru_cache``, so
    every caller receives the same dict object; it should not be mutated.
    """
    # TODO add CeilDiv (it doesn't appear in the index_expr)

    # TODO default to some decompositions if the interpreter doesn't have them
    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)

    HANDLERS = {
        sympy.Or: "or_",
        sympy.And: "and_",
        sympy.Eq: "eq",
        sympy.Ne: "ne",
        sympy.Lt: "lt",
        sympy.Gt: "gt",
        sympy.Le: "le",
        sympy.Ge: "ge",
        sympy.Not: "not_",
        IntTrueDiv: "int_truediv",
        FloatTrueDiv: "truediv",
        FloorDiv: "floordiv",
        CleanDiv: "floordiv",  # TODO: hmm?
        TruncToFloat: "trunc",
        Where: "where",
        sympy.Add: "add",
        sympy.Mul: "mul",
        FloatPow: "pow",
        PowByNatural: "pow_by_natural",
        # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
        # Do NOT use builtin Pow for floats
        # TODO: There is a hazard here, if we have float * float it will
        # also get turned into Pow(float, 2) but we don't want this because
        # pow_by_natural is assumed to only be integers.  Probably the fix is
        # to add a FloatMul to impede this optimization
        sympy.Pow: "pow_by_natural",
        Mod: "mod",
        PythonMod: "mod",  # TODO: this is wrong
        # TODO: Inductor can generate these, but it's ill-specified which
        # semantics were intended here.  Needs to be cleaned up along with
        # FloorDiv in a bigger cleanup
        sympy.Mod: "mod",
        sympy.Abs: "abs",
        sympy.log: "log",
        sympy.exp: "exp",
        sympy.Min: "minimum",
        sympy.Max: "maximum",
        Min: "minimum",
        Max: "maximum",
        ModularIndexing: "modular_indexing",
        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
        sympy.Piecewise: "piecewise",
        Identity: "identity",
        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
        RoundDecimal: "round_decimal",
    }
    # Trigonometric / hyperbolic functions: the handler name equals the
    # attribute name on the sympy module.
    for name in ("cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"):
        HANDLERS[getattr(sympy, name)] = name

    return HANDLERS
109

110

111
# Handler names whose Sympy counterparts (Add, Mul, Min, Max, And, Or) may
# carry more than two arguments; such calls are evaluated as a left-to-right
# chain of binary handler calls.
ASSOCIATIVE_OPS = {
    "add",
    "mul",
    "minimum",
    "maximum",
    "and_",
    "or_",
}
112

113

114
# These handlers are special because they take an extra dtype argument
# specifying what they should convert to, and we need to appropriately set
# this up when we convert from Sympy.  A reasonable default when you
# are translating is to conservatively do int64, and then narrow these
# arguments later when you discover you can narrow the index range.  But
# if you already know that 32-bit indexing is OK, you can directly do the
# sympy translation with index_dtype=torch.int32
#
# Built once at import time rather than on every (recursive) handler call.
_INDEX_DTYPE_HANDLERS = {
    TruncToInt: "trunc_to_int",
    sympy.floor: "floor_to_int",
    sympy.ceiling: "ceil_to_int",
    FloorToInt: "floor_to_int",
    CeilToInt: "ceil_to_int",
    RoundToInt: "round_to_int",
}


def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
    """Dispatch the top-level operation of ``expr`` to a method of ``analysis``.

    Args:
        analysis: handler object; methods are looked up on it by name
            (``sqrt``, ``to_dtype``, the ``*_to_int`` converters, and the
            names listed in ``handlers()``).
        args: the already-interpreted values of ``expr.args``, in order.
        expr: the Sympy expression whose head (``expr.func``) selects the
            handler.
        index_dtype: extra dtype argument passed to the integer-conversion
            handlers in ``_INDEX_DTYPE_HANDLERS``.

    Returns:
        Whatever the selected handler returns.

    Raises:
        KeyError: if ``expr.func`` has no ``_torch_handler_name`` attribute
            and no entry in ``handlers()``.
        Exception: any exception raised by the handler is logged and re-raised.
    """
    # Special cases
    # x ** (1/2) is routed to sqrt instead of the generic Pow handler
    # (which maps to pow_by_natural).
    if isinstance(expr, sympy.Pow) and isinstance(
        expr.args[1], sympy.core.numbers.Half
    ):
        return analysis.sqrt(args[0])
    if isinstance(expr, ToFloat):
        return analysis.to_dtype(args[0], torch.float64)

    if (handler_name := _INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
        return getattr(analysis, handler_name)(*args, index_dtype)

    # Expression classes may name their own handler; otherwise use the table.
    if hasattr(expr.func, "_torch_handler_name"):
        handler_name = expr.func._torch_handler_name
    else:
        handler_name = handlers()[expr.func]
    handler = getattr(analysis, handler_name)

    try:
        if handler_name in ASSOCIATIVE_OPS:
            # n-ary Add/Mul/Min/Max/And/Or: fold left-to-right into binary
            # calls, i.e. handler(handler(a0, a1), a2), ...
            assert len(args) > 1
            result = functools.reduce(handler, args)
        else:
            result = handler(*args)
    except Exception:
        log.warning("failed while executing %s(%s)", handler_name, args)
        raise

    log.debug("%s(%s) -> %s", handler_name, args, result)
    return result
161

162

163
def sympy_interp(
    analysis,
    env: Dict[sympy.Symbol, Any],
    expr: Union[sympy.Expr, SympyBoolean],
    *,
    index_dtype=torch.int64,
):
    """Evaluate ``expr`` by recursively dispatching its nodes to ``analysis``.

    Args:
        analysis: handler object.  Constants are produced with
            ``analysis.constant(value, dtype)``; every other node is
            dispatched by ``_run_sympy_handler``.
        env: value to substitute for each free ``sympy.Symbol`` in ``expr``.
        expr: the Sympy expression or boolean to interpret.
        index_dtype: dtype handed to the integer-conversion handlers
            (``floor_to_int``, ``ceil_to_int``, ``trunc_to_int``,
            ``round_to_int``).  It applies to every subexpression, not only
            the root.

    Returns:
        The value produced by ``analysis`` for ``expr``.

    Raises:
        KeyError: if ``expr`` contains a symbol that is missing from ``env``.
    """
    # Handle base cases
    dtype = None
    if isinstance(expr, BooleanAtom):
        dtype = torch.bool
    elif isinstance(expr, sympy.Integer):
        dtype = torch.int64
    elif isinstance(expr, sympy.Number):
        dtype = torch.double

    if dtype is not None:
        return analysis.constant(expr, dtype)
    elif isinstance(expr, sympy.Symbol):
        return env[expr]

    # Recursive case.  index_dtype must be forwarded to the children;
    # otherwise nested FloorToInt/CeilToInt/TruncToInt/RoundToInt nodes
    # silently fall back to the int64 default and ignore the caller's choice.
    return _run_sympy_handler(
        analysis,
        [
            sympy_interp(analysis, env, arg, index_dtype=index_dtype)
            for arg in expr.args
        ],  # type: ignore[arg-type]
        expr,
        index_dtype=index_dtype,
    )  # type: ignore[arg-type]
191

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.