import collections
import dataclasses
import itertools
import logging
import re
import typing
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch

import sympy

import torch
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

from .codegen.common import index_prevent_reordering
from .utils import (
    get_dtype_size,
    reduction_num_outputs,
    sympy_index_symbol,
    sympy_str,
    sympy_subs,
    VarRanges,
)
from .virtualized import OpsHandler, ReductionType, V
log = logging.getLogger(__name__)

# Symbols whose name contains "indirect" or "tmp" are treated as coming from
# indirect (data-dependent) indexing rather than from loop induction variables.
# Returns a regex match object (truthy) or None.
is_indirect = re.compile(r"indirect|tmp").search

# Any dependency edge tracked for a buffer: a concrete indexed access
# (MemoryDep), a whole-buffer dependency (StarDep), or a pure ordering
# constraint (WeakDep).
Dep = Union["MemoryDep", "StarDep", "WeakDep"]
class MemoryDep(typing.NamedTuple):
    """A single read/write of buffer ``name`` at symbolic ``index``.

    ``var_names``/``size`` are the (parallel) iteration variables appearing in
    ``index`` and their ranges.
    """

    name: str
    index: sympy.Expr
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]

    def __repr__(self):
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges})"

    @property
    def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...}"""
        return dict(zip(self.var_names, self.size))

    def get_numel(self) -> sympy.Expr:
        # For indirect accesses we cannot tell which elements are touched,
        # so conservatively use the whole buffer's numel.
        if self.is_indirect():
            numel = V.graph.get_numel(self.name)
        else:
            # Product of the ranges of only those vars that actually appear
            # in the index expression.
            vars = set(self.index.free_symbols)
            numel = sympy.Integer(1)
            for var, size in zip(self.var_names, self.size):
                if var in vars:
                    numel = numel * size
        return numel

    def rename(self, renames: Dict[str, str]) -> "MemoryDep":
        """Return a copy pointing at the renamed buffer (or self if unmapped)."""
        if self.name in renames:
            return MemoryDep(
                renames[self.name], self.index, var_names=self.var_names, size=self.size
            )
        return self

    def numbytes_hint(self):
        # Approximate byte count: element count hint times element size.
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        # A bare iteration variable as the index means unit-stride access.
        return isinstance(self.index, sympy.Symbol) and self.index in self.var_names

    def is_scalar(self) -> bool:
        # Constant index, or a symbol that is not one of our loop vars and
        # not an indirect load, addresses a single element.
        if isinstance(self.index, sympy.Symbol):
            return self.index not in self.var_names and not self.is_indirect()
        return isinstance(self.index, (int, sympy.Integer))

    def is_indirect(self) -> bool:
        # Uses the module-level `is_indirect` name heuristic on each symbol.
        return any(is_indirect(v.name) for v in self.index.free_symbols)
class StarDep(typing.NamedTuple):
    """Dependency on an entire buffer, with no specific index."""

    name: str

    @property
    def index(self):
        raise NotImplementedError("StarDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # The whole buffer is (potentially) touched.
        return V.graph.get_numel(self.name)

    def rename(self, renames: Dict[str, str]) -> "StarDep":
        """Return a copy pointing at the renamed buffer (or self if unmapped)."""
        if self.name in renames:
            return StarDep(renames[self.name])
        return self

    def numbytes_hint(self):
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        return False

    def is_scalar(self) -> bool:
        return False

    def is_indirect(self) -> bool:
        return False
class WeakDep(typing.NamedTuple):
    """Pure ordering constraint on a buffer: no data actually flows."""

    name: str

    @property
    def index(self):
        raise NotImplementedError("WeakDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # No data is read, so the element count is trivially 1.
        return sympy.Integer(1)

    def rename(self, renames: Dict[str, str]) -> "WeakDep":
        """Return a copy pointing at the renamed buffer (or self if unmapped)."""
        if self.name in renames:
            return WeakDep(renames[self.name])
        return self

    def numbytes_hint(self):
        return 1  # inserted only for ordering, not a real data transfer

    def has_unbacked_symbols(self):
        return False

    def is_contiguous(self) -> bool:
        return False
class IndexExprDep(typing.NamedTuple):
    """An index expression used by a kernel (no buffer attached),
    with the iteration variables and ranges it is expressed over."""

    index: sympy.Expr
    var_names: Tuple[sympy.Symbol, ...]
    size: Tuple[sympy.Expr, ...]
@dataclasses.dataclass
class ReadWrites:
    """All dependencies recorded for one kernel body.

    ``reads``/``writes`` are Dep sets, ``index_exprs`` the bare index
    expressions, ``range_vars``/``var_ranges`` the iteration space, and
    ``op_counts`` how many times each op was emitted.
    """

    reads: Set[Dep]
    writes: Set[Dep]
    index_exprs: Set[IndexExprDep]
    range_vars: Optional[List[sympy.Expr]] = None
    var_ranges: Optional[VarRanges] = None
    op_counts: typing.Counter[str] = dataclasses.field(
        default_factory=collections.Counter
    )

    def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
        """Rename the buffers referenced by every read/write dep."""
        return ReadWrites(
            {dep.rename(renames) for dep in self.reads},
            {dep.rename(renames) for dep in self.writes},
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def with_read(self, dep: Dep) -> "ReadWrites":
        """Return a copy with one extra (index-less) read dependency."""
        assert isinstance(dep, (WeakDep, StarDep))
        return ReadWrites(
            set.union(self.reads, {dep}),
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def merge(self, other: "ReadWrites"):
        """Union with another ReadWrites; anything we write is not a read."""
        reads = set.union(self.reads, other.reads)
        writes = set.union(self.writes, other.writes)
        index_exprs = set.union(self.index_exprs, other.index_exprs)
        op_counts = collections.Counter(self.op_counts)
        op_counts.update(other.op_counts)
        return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts)

    @staticmethod
    def merge_list(read_writes: List["ReadWrites"]):
        """Merge many ReadWrites at once (same semantics as chained merge)."""
        all_writes = set.union(*[rw.writes for rw in read_writes])
        all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes
        all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes])

        op_counts: typing.Counter[Any] = collections.Counter()
        for rw in read_writes:
            op_counts.update(rw.op_counts)

        return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts)

    def remove_reads(self, rem_reads):
        """Return a copy with ``rem_reads`` removed from the read set."""
        return ReadWrites(
            self.reads - rem_reads,
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def reads_and_writes(self):
        """Iterate over all deps, reads first then writes."""
        return itertools.chain(self.reads, self.writes)
class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
    """Ops handler that records every load/store/index_expr as a Dep."""

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        super().__init__()
        self._reads: Set[Dep] = set()
        self._writes: Set[MemoryDep] = set()
        self._index_exprs: Set[IndexExprDep] = set()
        self._var_ranges: VarRanges = var_ranges
        self._normalize: bool = normalize

    def canonicalize(
        self, index: sympy.Expr
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        """Return (index, var_names, sizes) in a canonical numbering.

        Without normalization this only drops size-1 dimensions; with
        normalization it simplifies the loops and renumbers the variables
        with the ``canonicalization_prefix`` so equivalent accesses compare
        equal across kernels.
        """
        if not self._normalize:
            sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
            var_names = tuple(
                k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
            )
            sizes = tuple(v for v in sizes if v != 1)
            return index, var_names, sizes

        # Try to further simplify the indexes even if simplify_loops didn't
        # convert it to the simplest form because of the interference from
        # different indexing formulas.
        free_symbols = index.free_symbols
        var_ranges = {
            k: V.graph.sizevars.simplify(v)
            for k, v in self._var_ranges.items()
        }
        index_vars = [*var_ranges.keys()]
        sizes = tuple(var_ranges.values())
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars,
            sizes,
            index_prevent_reordering([index], index_vars, sizes),
        )

        # Assign new variables to each dimension to deal with numbering
        # mismatches (d0, d1, d2 could become d0, d2 which won't match d0, d1).
        new_vars, add_var = var_builder(canonicalization_prefix())
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)

        new_vars = [*new_vars.keys()]
        new_sizes = [*new_sizes]
        free_symbols = index.free_symbols
        # Drop trailing vars that the simplified index no longer uses.
        while new_vars and new_vars[-1] not in free_symbols:
            new_vars.pop()
            new_sizes.pop()

        return index, tuple(new_vars), tuple(new_sizes)

    def load(self, name: str, index: sympy.Expr) -> str:
        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
        return f"load({name}, {sympy_str(index)})"

    def load_seed(self, name: str, index: int):
        assert isinstance(index, int)
        return self.load(name, sympy.Integer(index))

    def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
        self._writes.add(MemoryDep(name, *self.canonicalize(index)))
        return f"store({name}, {sympy_str(index)}, {value}, {mode})"

    def store_reduction(self, name: str, index, value) -> str:
        return self.store(name, index, f"store_reduction({value})")

    def index_expr(self, index: sympy.Expr, dtype) -> str:
        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
        return f"index_expr({sympy_str(index)}, {dtype})"

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        # bucketize scans the whole offsets buffer, so record a StarDep.
        self._reads.add(StarDep(offsets_name))
        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
"""Shim to count how many times each op is used"""
308
def __init__(self, inner):
310
self.parent_handler = inner
311
self._op_counts: typing.Counter[Any] = collections.Counter()
313
def __getattr__(self, name):
314
self._op_counts[name] += 1
315
return getattr(self.parent_handler, name)
318
class RecordLoadStore(V.KernelFormatterHandler):  # type: ignore[name-defined]
    """Dep-recording handler: an op-counting shim around _RecordLoadStoreInner,
    formatted through KernelFormatterHandler."""

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        parent_handler = _RecordLoadStoreInner(
            var_ranges=var_ranges, normalize=normalize
        )
        parent_handler = _OpCounter(parent_handler)
        super().__init__(parent_handler=parent_handler)
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
    """Return (var_ranges, add_var): add_var(length) mints a fresh symbol
    named ``{prefix}0``, ``{prefix}1``, ... and records its range."""
    cnt = itertools.count()
    var_ranges: VarRanges = dict()

    def add_var(length: sympy.Expr) -> sympy.Symbol:
        v = sympy_index_symbol(f"{prefix}{next(cnt)}")
        var_ranges[v] = length
        return v

    return var_ranges, add_var
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
    """Create one fresh index variable per dimension of each argsize
    (size-1 dims included); returns (per-arg var lists, var_ranges)."""
    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Symbol]] = []
    for size in argsizes:
        args.append(list(map(add_var, size)))
    return args, var_ranges
def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
    """Like index_vars_no_squeeze, but squeezes out size-1 dimensions first
    and reindexes back so each arg still gets one expr per original dim."""
    from .ir import SqueezeView

    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Expr]] = []
    new_sizes: List[List[sympy.Expr]] = []
    for size in argsizes:
        new_size, reindex = SqueezeView.squeezer(size)
        new_sizes.append(new_size)
        args.append(reindex(list(map(add_var, new_size))))
    return args, var_ranges
def extract_read_writes(
    fn: Callable[..., Any],
    *argsizes: Tuple[sympy.Expr, ...],
    normalize: bool = False,
    prefix: str = "d",
):
    """Run ``fn`` under a recording handler and return its ReadWrites.

    NOTE(review): reconstructed from a damaged source; the normalize branch
    follows upstream inductor — confirm against the original file.
    """
    args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
    rw = RecordLoadStore(var_ranges, normalize=normalize)
    with V.set_ops_handler(rw):
        fn(*args)

    if normalize:
        range_vars = []  # the number of vars could differ due to normalization
    else:
        range_vars = list(itertools.chain.from_iterable(args))

    inner = rw.parent_handler.parent_handler
    return ReadWrites(
        set(inner._reads),
        set(inner._writes),
        inner._index_exprs,
        range_vars,
        var_ranges,
        rw.parent_handler._op_counts,
    )
def extract_input_node_reduction_ranges(
    input_node: "torch._inductor.ir.TensorBox",
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
    """
    Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
    It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
    In this case, reduction_sizes of the Reduction nodes need to be the same.
    Otherwise returns (None, None).
    """

    from .ir import ComputedBuffer, Loops

    if isinstance(input_node.data, ComputedBuffer):
        # Input node has already been realized: use its own sizes.
        size = input_node.get_size()
        reduction_size = input_node.get_reduction_size()
        if len(reduction_size) > 0:
            return (size, reduction_size)
        else:
            return (None, None)

    if not isinstance(input_node.data.data, Loops):  # type: ignore[attr-defined]
        # Other IRNodes do not carry reduction ranges.
        return (None, None)

    # Unrealized node: walk its reads (breadth-first through unrealized
    # buffers) looking for realized reduction buffers, and require that all
    # of them agree on (size, reduction_size).
    reads = input_node.get_reads()
    reduction_size = None
    size = None
    while reduction_size is None and len(reads) > 0:
        seen = set()
        new_reads = []
        for read in reads:
            if not isinstance(read, MemoryDep):
                continue
            if read.name in seen:
                continue
            seen.add(read.name)
            buffer = V.graph.get_buffer(read.name)
            if buffer is None:
                continue
            if (
                isinstance(buffer, ComputedBuffer)
                and len(buffer.get_reduction_size()) > 0
            ):
                if reduction_size is None:
                    reduction_size = buffer.get_reduction_size()
                    size = buffer.get_size()
                elif (
                    reduction_size != buffer.get_reduction_size()
                    or size != buffer.get_size()
                ):
                    # Conflicting reduction shapes among the inputs.
                    return (None, None)
            else:
                new_reads.extend(buffer.get_reads())
        if reads == new_reads:
            # No progress: stop rather than loop forever.
            return (size, reduction_size)
        else:
            reads = new_reads
    return (size, reduction_size)
def canonicalization_prefix():
    """Prefix of the renumbered index variables produced by canonicalize()
    (c0, c1, ... as in MemoryDep.ranges' docstring)."""
    return "c"
class FreeUnbackedSymbolsOpsHandler:
    """Ops handler that only collects unbacked symbols seen in any argument."""

    symbols: Set[sympy.Symbol]

    def __init__(self):
        self.symbols = set()

    def __getattr__(self, name: str) -> Callable[..., Any]:
        # Generic fallback for every op: scan all sympy arguments for
        # unbacked symbols; the op itself produces no value (returns None).
        def inner(*args, **kwargs):
            for a in itertools.chain(args, kwargs.values()):
                if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
                    self.symbols |= free_unbacked_symbols(a)

        return inner

    def indirect_indexing(self, index_var, size, check=True) -> sympy.Symbol:
        assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
        self.symbols |= free_unbacked_symbols(size)
        return sympy_index_symbol(f"({str(index_var)})")

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[None, Tuple[None, ...]],
    ) -> Union[None, Tuple[None, ...]]:
        # Match the arity the reduction type produces, but with None values.
        num_values = reduction_num_outputs(reduction_type)
        return (None,) * num_values if num_values > 1 else None
def _typecheck_FreeUnbackedSymbolsOpsHandler(
    h: FreeUnbackedSymbolsOpsHandler,
) -> OpsHandler[None]:
    # Static-typing helper only: asserts the handler satisfies the
    # OpsHandler protocol. Identity at runtime.
    return h
def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
    """Run ``fn(index[, rindex])`` under the collecting handler and return the
    set of unbacked symbols its ops referenced."""
    from .ir import FlexibleLayout

    args = [index, rindex] if rindex is not None else [index]
    handler = FreeUnbackedSymbolsOpsHandler()
    # indirect_indexing would normally be rejected on FlexibleLayout; allow
    # it while we are only scanning for symbols.
    with V.set_ops_handler(handler), patch.object(
        FlexibleLayout, "allow_indexing", True
    ):
        fn(*args)
    return handler.symbols