pytorch

Форк
0
118 строк · 3.4 Кб
1
"""
2
This is a simple interpreter for Sympy expressions that dispatches to
3
classes following the torch._inductor.virtualized calling convention.
4
For directness, the interpreter takes the handler directly rather than
5
consulting the TLS.  It does not use most of the methods on the full
6
handler; only those with corresponding Sympy expressions.  To see an example
7
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
8
"""
9

10
import functools
11
from typing import Any, Dict, Union
12

13
import sympy
14
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
15

16
import torch
17
from .functions import (
18
    CleanDiv,
19
    FloorDiv,
20
    IsNonOverlappingAndDenseIndicator,
21
    Mod,
22
    ModularIndexing,
23
    Pow,
24
    Round,
25
    RoundDecimal,
26
    TrueDiv,
27
    Where,
28
)
29

30

31
# TODO: Dedupe this with SYMPY_INTERP
32

33

34
@functools.lru_cache(None)
35
def handlers():
36
    # TODO add CeilDiv (it doesn't appear in the index_expr)
37

38
    # TODO default to some decompositions if the interpreter doesn't have them
39
    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
40

41
    HANDLERS = {
42
        sympy.Or: "or_",
43
        sympy.And: "and_",
44
        sympy.Eq: "eq",
45
        sympy.Ne: "ne",
46
        sympy.Lt: "lt",
47
        sympy.Gt: "gt",
48
        sympy.Le: "le",
49
        sympy.Ge: "ge",
50
        sympy.Not: "not_",
51
        TrueDiv: "truediv",
52
        FloorDiv: "floordiv",
53
        CleanDiv: "div",
54
        Where: "where",
55
        sympy.Add: "add",
56
        sympy.Mul: "mul",
57
        Pow: "pow",
58
        sympy.Pow: "pow",
59
        Mod: "mod",
60
        sympy.Mod: "mod",
61
        sympy.Abs: "abs",
62
        sympy.log: "log",
63
        sympy.exp: "exp",
64
        sympy.floor: "floor",
65
        sympy.ceiling: "ceil",
66
        sympy.Min: "minimum",
67
        sympy.Max: "maximum",
68
        ModularIndexing: "modular_indexing",
69
        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
70
        sympy.Piecewise: "piecewise",
71
        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
72
        Round: "round",
73
        RoundDecimal: "round",
74
    }
75
    for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
76
        HANDLERS[getattr(sympy, name)] = name
77

78
    return HANDLERS
79

80

81
ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
82

83

84
def sympy_interp(
85
    analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
86
):
87
    # Handle base cases
88
    dtype = None
89
    if isinstance(expr, BooleanAtom):
90
        dtype = torch.bool
91
    elif isinstance(expr, sympy.Integer):
92
        dtype = torch.int64
93
    elif isinstance(expr, sympy.Number):
94
        dtype = torch.double
95

96
    if dtype is not None:
97
        return analysis.constant(expr, dtype)
98
    elif isinstance(expr, sympy.Symbol):
99
        return env[expr]
100

101
    # Special cases
102
    if isinstance(expr, sympy.Pow) and isinstance(
103
        expr.args[1], sympy.core.numbers.Half
104
    ):
105
        return analysis.sqrt(sympy_interp(analysis, env, expr.args[0]))
106

107
    # Recursive case
108
    args = [sympy_interp(analysis, env, arg) for arg in expr.args]  # type: ignore[arg-type]
109
    handler_name = handlers()[expr.func]
110
    handler = getattr(analysis, handler_name)
111
    if handler_name in ASSOCIATIVE_OPS:
112
        assert len(args) > 1
113
        acc = handler(args[0], args[1])
114
        for i in range(2, len(args)):
115
            acc = handler(acc, args[i])
116
        return acc
117
    else:
118
        return handler(*args)
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.