pytorch

Форк
0
175 строк · 6.2 Кб
1
import logging
2

3
from typing import Dict, Optional, Tuple, Type
4

5
import sympy
6

7
from torch.utils._sympy.functions import FloorDiv
8

9
log = logging.getLogger(__name__)

# Mirror table for relational classes: 'a op b' states the same thing as
# 'b mirror(op) a'. Classes that are not keys of this table are considered
# unsupported by the solver below.
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
    # Symmetric relations: swapping the operands keeps the operator.
    sympy.Eq: sympy.Eq,
    sympy.Ne: sympy.Ne,
    # Ordering relations: swapping the operands flips the direction, while
    # strictness ('>' vs. '>=') is kept.
    sympy.Ge: sympy.Le,
    sympy.Gt: sympy.Lt,
    sympy.Le: sympy.Ge,
    sympy.Lt: sympy.Gt,
}

# Relational classes whose direction depends on the order of the operands.
# Multiplying or dividing both sides of these by a negative value requires
# mirroring the operator.
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
21

22

23
# Looks up the mirrored counterpart of a relational class, i.e. the class
# 'op2' such that 'a op b' is equivalent to 'b op2 a'.
#
# Returns 'None' if 'type' is not one of the supported relational classes.
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    # 'dict.get' already yields 'None' for classes missing from the table.
    return _MIRROR_REL_OP.get(type)
25

26

27
# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
#
# Returns a tuple of:
#   1. The simplified expression
#   2. The expression on the right-hand side
#
# Returns 'None' if it can't reach a state where the only thing in the left
# hand side is 'thing'. That includes the cases where:
#   - 'expr' is not a supported relational operation
#   - 'thing' appears on both sides of 'expr'
#   - 'thing' does not appear in 'expr' at all
#
# 'trials': maximum number of times 'try_solve' will apply the rewriting rules
# while trying to isolate 'thing' to the left-hand side. It stops earlier if a
# trial doesn't change the expression.
#
# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
# inequalities.
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    mirror = mirror_rel_op(type(expr))

    # Ignore unsupported expressions:
    #   - Those that are not relational operations
    #   - Those that don't have a mirror (just avoiding unexpected classes)
    if not isinstance(expr, sympy.Rel) or mirror is None:
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    lhs_has_thing = expr.lhs.has(thing)
    rhs_has_thing = expr.rhs.has(thing)

    # Give up when 'thing' appears on both sides of the relational expression.
    # That is because the rewriting rules assume the thing we are trying to
    # isolate is on exactly one side (which is then made the left-hand side).
    if lhs_has_thing and rhs_has_thing:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Collect the candidates that have 'thing' in their left-hand side. If
    # 'thing' is in the right-hand side of 'expr', mirror it first:
    # a < b ==> b > a
    expressions = []

    if lhs_has_thing:
        expressions.append(expr)
    if rhs_has_thing:
        expressions.append(mirror(expr.rhs, expr.lhs))

    for e in expressions:
        # Building the mirrored relation may evaluate it to a boolean
        # constant. There is nothing to isolate in that case, so skip the
        # candidate instead of failing.
        if not isinstance(e, sympy.Rel):
            continue

        for _ in range(trials):
            trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
            # Stop if there was no change in this trial.
            if trial == e:
                break
            e = trial  # type: ignore[assignment]

        # Return if we were able to isolate 'thing' on the left-hand side.
        # Note that 'e' may no longer be a relational operation here (e.g.
        # the 'FloorDiv' rules may produce 'And'/'Or' expressions).
        if isinstance(e, sympy.Rel) and e.lhs == thing:
            return e, e.rhs

    return None
94

95

96
# Applies one round of rewriting rules to 'expr', trying to get it closer to
# the form: 'thing' op <something without 'thing'>.
#
# The rules, applied in order, are:
#   1. Move the terms of a left-hand side sum that don't contain 'thing' to
#      the right-hand side.
#   2. Divide both sides by the factors of a left-hand side product that
#      don't contain 'thing' (mirroring inequalities if the divisor is
#      negative).
#   3. If 'floordiv_inequality' is set, replace a relation whose left-hand
#      side is a 'FloorDiv' by relations on its numerator.
#
# Returns the rewritten expression, which is not necessarily a relational
# operation (rule 3 may return 'And'/'Or'). Expressions that no rule applies
# to are returned as they are.
def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    e = expr
    op = type(expr)

    if isinstance(e, sympy.Rel):
        # Move any constants in the left-hand side to the right-hand side.
        lhs_not_thing = (
            sum(a for a in e.lhs.args if not a.has(thing))
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain thing.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
        is_inequality = isinstance(e, INEQUALITY_TYPES)

        if is_inequality:
            # If we can't tell whether 'other' is negative or positive, we do
            # nothing. That is because we don't know whether we have to
            # mirror the operation or not.
            can_divide = other.is_negative is not None
        else:
            # 'other * x == 0' also holds when 'other' is zero, no matter what
            # 'x' is. So, if the right-hand side is zero, dividing is only
            # sound when 'other' is known to be non-zero.
            can_divide = not (rhs.is_zero and other.is_zero is not False)

        if can_divide:
            # Divide both sides by 'other'.
            lhs = lhs / other
            rhs = rhs / other

            # If 'e' is an inequality and 'other' is negative, we have to
            # mirror the expression.
            if is_inequality and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    #
    # The rule is picked based on 'e', not on the original 'expr': the steps
    # above may have mirrored the operation. For example, '-1 * (a // 2) > 3'
    # becomes 'a // 2 < -3', which has to be handled as a '<'.
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        numerator, denominator = e.lhs.args

        # a // b == expr
        # => a >= (b * expr) and a < (b * (expr + 1))
        if isinstance(e, sympy.Eq):
            return sympy.And(
                sympy.Ge(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b != expr
        # => a < (b * expr) or a >= (b * (expr + 1))
        if isinstance(e, sympy.Ne):
            return sympy.Or(
                sympy.Lt(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # The transformations below rely on b being positive, which was
        # checked above.
        # a // b > expr  => a >= b * (expr + 1)
        # a // b >= expr => a >= b * expr
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Ge(numerator, (quotient * denominator))  # type: ignore[arg-type]
        # a // b < expr  => a < b * expr
        # a // b <= expr => a < b * (expr + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Lt(numerator, (quotient * denominator))  # type: ignore[arg-type]

    return e
176

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.