pytorch

Форк
0
/
singleton_int.py 
94 строки · 2.9 Кб
1
import sympy
2
from sympy.multipledispatch import dispatch
3

4
__all__ = ["SingletonInt"]
5

6

7
class SingletonInt(sympy.AtomicExpr):
8
    # This is probably not super important unless we are in multiple dispatch
9
    # situations with other more exotic Expr types.
10
    _op_priority = 99999
11

12
    def __new__(cls, *args, coeff=None, **kwargs):
13
        instance = super().__new__(cls, *args, **kwargs)
14
        return instance
15

16
    # The semantics of this class should match that of NestedIntSymNodeImpl in
17
    # c10/core/NestedIntSymNodeImpl.h
18
    def __init__(self, val, *, coeff=1):
19
        self._val = val
20
        self._coeff = coeff
21
        super().__init__()
22

23
    # See NOTE [ Inequalities with nested int ]
24
    def _eval_Eq(self, other):
25
        if (
26
            isinstance(other, SingletonInt)
27
            and other._val == self._val
28
            and self._coeff == other._coeff
29
        ):
30
            return sympy.true
31
        else:
32
            return sympy.false
33

34
    # This is necessary so that calling expr.free_symbols on exprs that contain
35
    # this Singleton does not error
36
    @property
37
    def free_symbols(self):
38
        return set()
39

40
    def __mul__(self, other):
41
        if isinstance(other, SingletonInt):
42
            raise ValueError(
43
                "SingletonInt cannot be multiplied by another SingletonInt"
44
            )
45
        return SingletonInt(self._val, coeff=self._coeff * other)
46

47
    def __rmul__(self, other):
48
        if isinstance(other, SingletonInt):
49
            raise ValueError(
50
                "SingletonInt cannot be multiplied by another SingletonInt"
51
            )
52
        return SingletonInt(self._val, coeff=self._coeff * other)
53

54
    # Make sure we promptly raise an error instead of falling back to building
55
    # an expression tree. There are probably more ops, how can we be exhaustive?
56
    def __add__(self, other):
57
        raise NotImplementedError("NYI")
58

59
    def __sub__(self, other):
60
        raise NotImplementedError("NYI")
61

62
    def __truediv__(self, other):
63
        raise NotImplementedError("NYI")
64

65
    def __floordiv__(self, other):
66
        raise NotImplementedError("NYI")
67

68
    def __mod__(self, other):
69
        raise NotImplementedError("NYI")
70

71

72
# See NOTE [ Inequalities with nested int ]
73
@dispatch(sympy.Integer, SingletonInt)
74
def _eval_is_ge(a, b):
75
    if a < 2:
76
        return sympy.false
77
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
78

79

80
@dispatch(SingletonInt, sympy.Integer)  # type: ignore[no-redef]
81
def _eval_is_ge(a, b):  # noqa: F811
82
    if b <= 2:
83
        return sympy.true
84
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
85

86

87
@dispatch(SingletonInt, SingletonInt)  # type: ignore[no-redef]
88
def _eval_is_ge(a, b):  # noqa: F811
89
    if a._val == b._val:
90
        if a._coeff >= b._coeff:
91
            return sympy.true
92
        else:
93
            return sympy.false
94
    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.