6
from torch.utils._sympy.value_ranges import ValueRanges
7
from .ir import LoopBody
8
from .utils import dominated_nodes
11
def val_expressable_in_32_bits(val):
12
if getattr(val, "is_Boolean", False):
15
if isinstance(val, sympy.Expr):
17
if val.is_Integer or val.is_Boolean:
22
# bound within mantissa
23
if isinstance(val, float):
24
return val <= (2**24) and val >= -(2**24)
26
if isinstance(val, int):
27
iinfo = torch.iinfo(torch.int32)
28
return val <= iinfo.max and val >= iinfo.min
30
raise Exception(f"Unexpected value {val}")
33
def range_expressable_in_32_bits(range):
34
return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
39
def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals):
40
# if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
41
# then its precision is set for that chain of uses, and we don't need to consider those
43
def skip_filter(node):
44
return node.target == "to_dtype" and node.args[2] in (
50
# TODO - there are dominated uses whose dtype does not depend on whether
51
# we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
52
# int32 without changing the output precision of the node. this case hasn't shown up
53
for dominated in dominated_nodes([node], skip_filter):
54
if dominated.target in ["store", "output"]:
57
if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
58
idx = int(dominated.target[len("set_indirect") :])
59
indirect_var = indirect_vars[idx]
61
# We check that we can compute all the indices it's involved in with int32
62
for index, expr in indices.items():
63
if indirect_var in expr.free_symbols:
64
index_val = replacement_vals[index]
66
if math.isinf(index_val.lower) or math.isinf(index_val.upper):
69
# all indices are integers, so make sure that we
70
# use the bounds of integers instead of floats.
71
# TODO - not sure if we should be doing int/float casts while tracing,
72
# might interfere with sympy.
74
index_val_int = ValueRanges[sympy.Expr](
75
int(index_val.lower), int(index_val.upper)
77
if not range_expressable_in_32_bits(index_val_int):
80
if not range_expressable_in_32_bits(bounds[dominated]):
83
args = list(node.args)
85
node.args = tuple(args)
88
def indexing_dtype_strength_reduction(loop_body: LoopBody):
90
Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
91
intermediaries from int64 to int32
93
bv = loop_body.bounds()
97
for node in loop_body.get_nodes()
99
node.target == "to_dtype"
100
and node.args[2] == torch.int64
101
and node not in bv.unbounded_vars
104
if not int64_dtype_nodes:
107
bounds = bv.get_bounds()
109
# TODO - if dominated node of one to_dtype is not expressible in int32,
110
# we should short circuit another to_dtype node if that node also dominates
111
for node in int64_dtype_nodes:
112
try_to_reduce_precision(
115
loop_body.indirect_vars,
116
loop_body.indexing_exprs,