pytorch
1import logging2
3from typing import Dict, Optional, Tuple, Type4
5import sympy6
7from torch.utils._sympy.functions import FloorDiv8
log = logging.getLogger(__name__)

# Relational class -> relational class expressing the same relation once the
# two operands are swapped. Eq and Ne are symmetric and map to themselves.
# (Insertion order is kept stable: Eq, Ne, Ge, Gt, Le, Lt.)
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
    sympy.Eq: sympy.Eq,  # a == b  <=>  b == a
    sympy.Ne: sympy.Ne,  # a != b  <=>  b != a
    sympy.Ge: sympy.Le,  # a >= b  <=>  b <= a
    sympy.Gt: sympy.Lt,  # a >  b  <=>  b <  a
    sympy.Le: sympy.Ge,  # a <= b  <=>  b >= a
    sympy.Lt: sympy.Gt,  # a <  b  <=>  b >  a
}

# The ordering relations (strict and non-strict), as opposed to Eq/Ne.
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
22
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    """Return the relational class equivalent to ``type`` with swapped operands.

    For example ``sympy.Lt`` yields ``sympy.Gt`` because ``a < b`` is the same
    relation as ``b > a``. Any class without an entry in ``_MIRROR_REL_OP``
    yields ``None``.

    Note: the parameter shadows the ``type`` builtin; it is kept for
    backward compatibility with keyword callers.
    """
    if type in _MIRROR_REL_OP:
        return _MIRROR_REL_OP[type]
    return None
26
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    """Try to rewrite ``expr`` so that only ``thing`` remains on its left-hand side.

    Args:
        expr: relational expression to solve (one of the classes that
            ``mirror_rel_op`` knows about).
        thing: the sub-expression to isolate.
        trials: maximum number of ``_try_isolate_lhs`` rounds attempted for
            each orientation of ``expr``; a round that changes nothing ends
            the attempt early.
        floordiv_inequality: forwarded to ``_try_isolate_lhs``; enables
            turning a ``FloorDiv`` left-hand side into (in)equalities.

    Returns:
        A tuple ``(solved, rhs)`` where ``solved`` is the rewritten relational
        whose left-hand side equals ``thing`` and ``rhs`` is its right-hand
        side. Returns ``None`` when ``expr`` is not a supported relational,
        when ``thing`` appears on both sides, or when no orientation could be
        reduced to ``thing`` alone on the left-hand side.
    """
    op_type = type(expr)
    flipped = mirror_rel_op(op_type)

    # Reject anything that is not a relational with a known mirror (this
    # also guards against unexpected Rel subclasses).
    if flipped is None or not isinstance(expr, sympy.Rel):
        log.debug("expression with unsupported type: %s", op_type)
        return None

    in_lhs = expr.lhs.has(thing)
    in_rhs = expr.rhs.has(thing)

    # Isolation only moves terms away from the left-hand side, so 'thing'
    # must be present on exactly one side to begin with.
    if in_lhs and in_rhs:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Collect each orientation of 'expr' that has 'thing' on its left-hand
    # side; when it sits on the right, swap the operands (a < b ==> b > a).
    candidates = []
    if in_lhs:
        candidates.append(expr)
    if in_rhs:
        candidates.append(flipped(expr.rhs, expr.lhs))

    for candidate in candidates:
        assert isinstance(candidate, sympy.Rel)

        current = candidate
        for _ in range(trials):
            step = _try_isolate_lhs(
                current, thing, floordiv_inequality=floordiv_inequality
            )
            # A round that leaves the expression unchanged cannot make
            # further progress.
            if step == current:
                break
            current = step  # type: ignore[assignment]

        # Success only if the result is still relational with 'thing' alone
        # on the left (e.g. an And/Or produced from a FloorDiv does not count).
        if isinstance(current, sympy.Rel) and current.lhs == thing:
            return current, current.rhs

    return None
95
def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    """Run one round of rewrites that move ``expr`` towards ``thing op rhs``.

    The rewrites, applied in order to a relational ``expr``, are:

    1. If the left-hand side is an ``Add``, subtract every term that does not
       contain ``thing`` from both sides.
    2. If the left-hand side is a ``Mul``, divide both sides by the product of
       the factors that do not contain ``thing``. For inequalities this is
       only done when that product is known to be strictly positive or
       strictly negative; a negative divisor mirrors the relation.
    3. If ``floordiv_inequality`` is set and the left-hand side is
       ``FloorDiv(a, b)`` with ``b`` known to be positive and an integer
       right-hand side, replace the relation by equivalent (in)equalities
       over ``a``.

    Args:
        expr: expression to rewrite. Non-relational input is returned as is.
        thing: the sub-expression being isolated.
        floordiv_inequality: enables rewrite 3.

    Returns:
        The rewritten expression. This is normally a relational, but may be an
        ``And``/``Or`` (from rewrite 3) or whatever sympy evaluates the
        rebuilt relation to.
    """
    e = expr
    op = type(expr)

    if isinstance(e, sympy.Rel):
        # Move the left-hand-side terms that don't contain 'thing' to the
        # right-hand side.
        lhs_not_thing = (
            sum([a for a in e.lhs.args if not a.has(thing)])
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(expr.lhs - lhs_not_thing, expr.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain 'thing'.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])

        is_inequality = isinstance(e, INEQUALITY_TYPES)
        # For inequalities, the sign of 'other' decides whether the relation
        # must be mirrored, and dividing by a possibly-zero value is unsound.
        # Only proceed when 'other' is known to be strictly positive or
        # strictly negative; `is_negative is False` alone would still admit
        # a non-negative (possibly zero) divisor.
        if not is_inequality or other.is_positive or other.is_negative:
            lhs = lhs / other
            rhs = rhs / other

            # Dividing an inequality by a negative value flips its direction.
            if is_inequality and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        numerator, denominator = e.lhs.args

        # Dispatch on the *current* relation 'e', not on the input 'expr':
        # the division step above may have mirrored the operator (e.g.
        # '-2*(x//3) < 4' became 'x//3 > -2').

        # a // b == c  =>  a >= b * c  and  a < b * (c + 1)
        if isinstance(e, sympy.Eq):
            return sympy.And(
                sympy.Ge(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b != c  =>  a < b * c  or  a >= b * (c + 1)
        if isinstance(e, sympy.Ne):
            return sympy.Or(
                sympy.Lt(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b >  c  =>  a >= b * (c + 1)
        # a // b >= c  =>  a >= b * c
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Ge(numerator, (quotient * denominator))  # type: ignore[arg-type]
        # a // b <  c  =>  a < b * c
        # a // b <= c  =>  a < b * (c + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Lt(numerator, (quotient * denominator))  # type: ignore[arg-type]

    return e