pytorch

Форк
0
/
optimize_indexing.py 
118 строк · 3.8 Кб
1
import math
2

3
import sympy
4

5
import torch
6
from torch.utils._sympy.value_ranges import ValueRanges
7
from .ir import LoopBody
8
from .utils import dominated_nodes
9

10

11
def val_expressable_in_32_bits(val):
12
    if getattr(val, "is_Boolean", False):
13
        return True
14

15
    if isinstance(val, sympy.Expr):
16
        assert val.is_number
17
        if val.is_Integer or val.is_Boolean:
18
            val = int(val)
19
        else:
20
            val = float(val)
21

22
    # bound within mantissa
23
    if isinstance(val, float):
24
        return val <= (2**24) and val >= -(2**24)
25

26
    if isinstance(val, int):
27
        iinfo = torch.iinfo(torch.int32)
28
        return val <= iinfo.max and val >= iinfo.min
29

30
    raise Exception(f"Unexpected value {val}")
31

32

33
def range_expressable_in_32_bits(range):
34
    return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
35
        range.upper
36
    )
37

38

39
def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
40
    # if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
41
    # then it's precision is set for that chain of uses, and we don't need to consider those
42
    # dominated values
43
    def skip_filter(node):
44
        return node.target == "to_dtype" and node.args[2] in (
45
            torch.int32,
46
            torch.float32,
47
            torch.float64,
48
        )
49

50
    # TODO - there are dominated uses whose dtype does not depend on whether
51
    # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
52
    # int32 without changing the output precision of the node. this case hasn't shown up
53
    for dominated in dominated_nodes([node], skip_filter):
54
        if dominated.target in ["store", "output"]:
55
            continue
56

57
        if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
58
            idx = int(dominated.target[len("set_indirect") :])
59
            indirect_var = indirect_vars[idx]
60

61
            # We check that we can compute all the indices it's involved in with int32
62
            for index, expr in indices.items():
63
                if indirect_var in expr.free_symbols:
64
                    index_val = replacement_vals[index]
65

66
                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
67
                        return
68

69
                    # all indices are integers, so make sure that we
70
                    # use the bounds of integers instead of floats.
71
                    # TODO - not sure if we should be doing int/float casts while tracing,
72
                    # might interfere with sympy.
73

74
                    index_val_int = ValueRanges[sympy.Expr](
75
                        int(index_val.lower), int(index_val.upper)
76
                    )
77
                    if not range_expressable_in_32_bits(index_val_int):
78
                        return
79

80
        if not range_expressable_in_32_bits(bounds[dominated]):
81
            return
82

83
    args = list(node.args)
84
    args[2] = torch.int32
85
    node.args = tuple(args)
86

87

88
def indexing_dtype_strength_reduction(loop_body: LoopBody):
89
    """
90
    Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
91
    intermediaries from int64 to int32
92
    """
93
    bv = loop_body.bounds()
94

95
    int64_dtype_nodes = [
96
        node
97
        for node in loop_body.get_nodes()
98
        if (
99
            node.target == "to_dtype"
100
            and node.args[2] == torch.int64
101
            and node not in bv.unbounded_vars
102
        )
103
    ]
104
    if not int64_dtype_nodes:
105
        return
106

107
    bounds = bv.get_bounds()
108

109
    # TODO - if dominated node of one to_dtype is not expressible in int32,
110
    # we should short circuit another to_dtype node if that node also dominates
111
    for node in int64_dtype_nodes:
112
        try_to_reduce_precision(
113
            node,
114
            bounds,
115
            loop_body.indirect_vars,
116
            loop_body.indexing_exprs,
117
            bv.replacement_vals,
118
        )
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.