pytorch

Форк
0
/
dependencies.py 
506 строк · 16.7 Кб
1
import collections
2
import dataclasses
3
import itertools
4
import logging
5
import re
6
import typing
7
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
8
from unittest.mock import patch
9

10
import sympy
11

12
import torch
13
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
14

15
from .codegen.common import index_prevent_reordering
16
from .utils import (
17
    get_dtype_size,
18
    reduction_num_outputs,
19
    sympy_index_symbol,
20
    sympy_str,
21
    sympy_subs,
22
    VarRanges,
23
)
24
from .virtualized import OpsHandler, ReductionType, V
25

26
log = logging.getLogger(__name__)
27
is_indirect = re.compile(r"indirect|tmp").search
28
Dep = Union["MemoryDep", "StarDep", "WeakDep"]
29

30

31
class MemoryDep(typing.NamedTuple):
    """A read or write of buffer ``name`` at symbolic ``index`` over loop vars."""

    name: str  # buffer name in the graph
    index: sympy.Expr  # type: ignore[assignment]
    var_names: Tuple[sympy.Symbol, ...]  # loop variables the index may use
    size: Tuple[sympy.Expr, ...]  # extent of each loop variable

    def __repr__(self):
        return f"MemoryDep({self.name!r}, {self.index}, {self.ranges})"

    @property
    def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
        """{c0: 128, c1: 512, ...}"""
        return dict(zip(self.var_names, self.size))

    def get_numel(self) -> sympy.Expr:
        """Number of elements this access touches.

        Product of the extents of the loop variables actually used by the
        index; an indirect access falls back to the whole buffer since the
        touched region cannot be determined symbolically.
        """
        if self.is_indirect():
            return V.graph.get_numel(self.name)
        # sympy's free_symbols is already a set; avoid the redundant set()
        # copy and don't shadow the `vars` builtin.
        used_symbols = self.index.free_symbols
        numel = sympy.Integer(1)
        for var, size in zip(self.var_names, self.size):
            if var in used_symbols:
                numel = numel * size
        return numel

    def rename(self, renames: Dict[str, str]) -> "MemoryDep":
        """Return a copy with the buffer name mapped through ``renames``."""
        if self.name in renames:
            return MemoryDep(
                renames[self.name], self.index, var_names=self.var_names, size=self.size
            )
        return self

    def numbytes_hint(self):
        # Element-count hint times element size; used by scheduling heuristics.
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
            V.graph.get_dtype(self.name)
        )

    def has_unbacked_symbols(self):
        # True when the access size depends on unbacked (data-dependent) symbols.
        return len(free_unbacked_symbols(self.get_numel())) > 0

    def is_contiguous(self) -> bool:
        # A bare loop variable as the index means a unit-stride access.
        return isinstance(self.index, sympy.Symbol) and self.index in self.var_names

    def is_scalar(self) -> bool:
        # Constant index, or a symbol that is neither a loop var nor indirect.
        if isinstance(self.index, sympy.Symbol):
            return self.index not in self.var_names and not self.is_indirect()
        return isinstance(self.index, (int, sympy.Integer))

    def is_indirect(self) -> bool:
        # Indirect loads introduce symbols whose names match "indirect"/"tmp".
        return any(is_indirect(v.name) for v in self.index.free_symbols)  # type: ignore[attr-defined]
81

82

83
class StarDep(typing.NamedTuple):
    """Dependency on an entire buffer rather than on specific elements."""

    name: str

    @property
    def index(self):
        raise NotImplementedError("StarDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # Whole-buffer dependency: size is the full element count.
        return V.graph.get_numel(self.name)

    def rename(self, renames: Dict[str, str]) -> "StarDep":
        """Return a StarDep on the renamed buffer, or self when unaffected."""
        if self.name not in renames:
            return self
        return StarDep(renames[self.name])

    def numbytes_hint(self):
        # Full buffer footprint in bytes (element count * dtype size).
        dtype = V.graph.get_dtype(self.name)
        return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(dtype)

    def has_unbacked_symbols(self):
        return bool(free_unbacked_symbols(self.get_numel()))

    def is_contiguous(self) -> bool:
        return False

    def is_scalar(self) -> bool:
        return False

    def is_indirect(self) -> bool:
        return False
115

116

117
# Used for tracking mutation ordering
# if A reads a buffer and B mutates it
# B must be ordered after A
#
# It is weak because if it turns out A's read is never used, we can still
# eliminate it
class WeakDep(typing.NamedTuple):
    """Ordering-only dependency on a buffer (see the comment above)."""

    name: str

    @property
    def index(self):
        raise NotImplementedError("WeakDep does not have an index")

    def get_numel(self) -> sympy.Expr:
        # Not a real data dependency; report a minimal size.
        return sympy.Integer(1)

    def rename(self, renames: Dict[str, str]) -> "WeakDep":
        """Return a WeakDep on the renamed buffer, or self when unaffected."""
        if self.name in renames:
            return WeakDep(renames[self.name])
        return self

    def numbytes_hint(self):
        return 1  # Purely inserted for ordering, not an actual dep

    def has_unbacked_symbols(self):
        return False

    def is_contiguous(self) -> bool:
        return False

    def is_scalar(self) -> bool:
        # Added for consistency with the other Dep kinds (StarDep defines
        # this); an ordering-only dep is never a scalar access.
        return False

    def is_indirect(self) -> bool:
        # Added for consistency with the other Dep kinds; never indirect.
        return False
146

147

148
class IndexExprDep(typing.NamedTuple):
    """Dependency on a computed index expression (no direct buffer access)."""

    index: sympy.Expr  # type: ignore[assignment]
    var_names: Tuple[sympy.Symbol, ...]  # loop variables the expression may use
    size: Tuple[sympy.Expr, ...]  # extent of each loop variable
152

153

154
@dataclasses.dataclass
class ReadWrites:
    """Collected memory dependencies (reads, writes, index exprs) of a kernel."""

    reads: Set[Dep]
    writes: Set[Dep]
    index_exprs: Set[IndexExprDep]
    range_vars: Optional[List[sympy.Expr]] = None
    var_ranges: Optional[VarRanges] = None
    op_counts: typing.Counter[str] = dataclasses.field(
        default_factory=collections.Counter
    )

    def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
        """Apply buffer renames to every read and write dependency."""
        renamed_reads = {dep.rename(renames) for dep in self.reads}
        renamed_writes = {dep.rename(renames) for dep in self.writes}
        return ReadWrites(
            renamed_reads,
            renamed_writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def with_read(self, dep: Dep) -> "ReadWrites":
        """Return a copy that additionally reads ``dep`` (whole-buffer deps only)."""
        assert isinstance(dep, (WeakDep, StarDep))
        return ReadWrites(
            self.reads | {dep},
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def merge(self, other: "ReadWrites"):
        """Combine two ReadWrites; reads covered by any write are dropped."""
        writes = self.writes | other.writes
        reads = (self.reads | other.reads) - writes
        index_exprs = self.index_exprs | other.index_exprs
        op_counts = collections.Counter(self.op_counts)
        op_counts.update(other.op_counts)
        # NOTE: range_vars/var_ranges are intentionally not propagated here.
        return ReadWrites(reads, writes, index_exprs, op_counts=op_counts)

    @staticmethod
    def merge_list(read_writes: List["ReadWrites"]):
        """Merge many ReadWrites; reads covered by any write are dropped."""
        all_writes = set.union(*(rw.writes for rw in read_writes))
        all_reads = set.union(*(rw.reads for rw in read_writes)) - all_writes
        all_index_exprs = set.union(*(rw.index_exprs for rw in read_writes))

        op_counts: typing.Counter[Any] = collections.Counter()
        for rw in read_writes:
            op_counts.update(rw.op_counts)

        return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts)

    def remove_reads(self, rem_reads):
        """Return a copy with ``rem_reads`` removed from the read set."""
        return ReadWrites(
            self.reads - rem_reads,
            self.writes,
            self.index_exprs,
            self.range_vars,
            self.var_ranges,
            op_counts=self.op_counts,
        )

    def reads_and_writes(self):
        """Iterate all reads followed by all writes."""
        return itertools.chain(self.reads, self.writes)
218

219

220
class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
    """Ops handler that records the memory deps of a traced kernel body.

    Installed via ``V.set_ops_handler`` (see ``extract_read_writes``); loads
    and stores are canonicalized into ``_reads``/``_writes`` and bare index
    expressions into ``_index_exprs``.
    """

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        super().__init__()
        # Dependencies accumulated while tracing.
        self._reads: Set[Dep] = set()
        self._writes: Set[MemoryDep] = set()
        self._index_exprs: Set[IndexExprDep] = set()
        # Loop variable -> extent for the kernel being traced.
        self._var_ranges: VarRanges = var_ranges
        # Whether canonicalize() should renumber loop variables (see below).
        self._normalize: bool = normalize

    def canonicalize(
        self, index: sympy.Expr
    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
        """Return ``(index, var_names, sizes)`` for this access.

        Size-1 dimensions are always dropped.  With ``self._normalize`` set,
        loops are re-simplified and variables renumbered with the canonical
        prefix so equivalent accesses from different kernels compare equal.
        """
        if not self._normalize:
            sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
            var_names = tuple(
                k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
            )
            sizes = tuple(v for v in sizes if v != 1)
            return index, var_names, sizes  # type: ignore[return-value]

        # Try to further simplify the indexes even if simplify_loops didn't
        # convert it to the simplest form because of the interference from
        # different indexing formulas.
        free_symbols = index.free_symbols
        var_ranges = {
            k: V.graph.sizevars.simplify(v)
            for k, v in self._var_ranges.items()
            # TODO(jansel): explore this further normalization
            # if k in free_symbols
        }
        index_vars = [*var_ranges.keys()]
        sizes = tuple(var_ranges.values())
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars,
            sizes,
            index_prevent_reordering([index], index_vars, sizes),
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        new_vars, add_var = var_builder(canonicalization_prefix())
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)

        new_vars = [*new_vars.keys()]
        new_sizes = [*new_sizes]
        free_symbols = index.free_symbols
        while new_vars and new_vars[-1] not in free_symbols:
            # Reduction has last (reduced) dim in its sizes, but
            # downstream users won't.  Normalize this away.
            new_vars.pop()
            new_sizes.pop()
        return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[arg-type]

    def load(self, name: str, index: sympy.Expr) -> str:
        # Record a read of ``name`` at the canonicalized index.
        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
        return f"load({name}, {sympy_str(index)})"

    def load_seed(self, name: str, index: int):
        # Seeds are read at a constant (Python int) offset.
        assert isinstance(index, int)
        return self.load(name, sympy.Integer(index))

    def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
        # Record a write of ``name`` at the canonicalized index.
        self._writes.add(MemoryDep(name, *self.canonicalize(index)))
        return f"store({name}, {sympy_str(index)}, {value}, {mode})"

    def store_reduction(self, name: str, index, value) -> str:
        # A reduction store is recorded the same way as a plain store.
        return self.store(name, index, f"store_reduction({value})")

    def index_expr(self, index: sympy.Expr, dtype) -> str:
        # Index expressions create a dependency even without a buffer access.
        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
        return f"index_expr({sympy_str(index)}, {dtype})"

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        # bucketize searches the whole offsets buffer, so record a
        # whole-buffer (Star) dependency rather than an indexed one.
        self._reads.add(StarDep(offsets_name))
        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
303

304

305
class _OpCounter:
306
    """Shim to count how many times each op is used"""
307

308
    def __init__(self, inner):
309
        super().__init__()
310
        self.parent_handler = inner
311
        self._op_counts: typing.Counter[Any] = collections.Counter()
312

313
    def __getattr__(self, name):
314
        self._op_counts[name] += 1
315
        return getattr(self.parent_handler, name)
316

317

318
class RecordLoadStore(V.KernelFormatterHandler):  # type: ignore[name-defined]
    """Formatter handler whose inner chain records deps and counts ops."""

    def __init__(self, var_ranges: VarRanges, normalize: bool):
        recorder = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize)
        super().__init__(parent_handler=_OpCounter(recorder))
325

326

327
def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
    """Return a fresh ``{var: extent}`` mapping plus a factory that mints
    sequentially numbered index variables registered in that mapping."""
    var_ranges: VarRanges = {}
    counter = itertools.count()

    def add_var(length: sympy.Expr) -> sympy.Symbol:
        var = sympy_index_symbol(f"{prefix}{next(counter)}")
        var_ranges[var] = length
        return var

    return var_ranges, add_var
337

338

339
def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
    """Create one fresh index variable per dimension of every size tuple,
    returning the variables grouped per argument plus their range mapping."""
    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Symbol]] = [
        [add_var(extent) for extent in size] for size in argsizes
    ]
    return args, var_ranges
345

346

347
def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
    """Like ``index_vars_no_squeeze`` but allocates variables only for the
    squeezed (size-1-free) shape, reindexing back to the original rank."""
    from .ir import SqueezeView

    var_ranges, add_var = var_builder(prefix)
    args: List[List[sympy.Expr]] = []
    new_sizes: List[List[sympy.Expr]] = []
    for size in argsizes:
        squeezed_size, reindex = SqueezeView.squeezer(size)
        new_sizes.append(squeezed_size)
        args.append(reindex([add_var(extent) for extent in squeezed_size]))
    return args, var_ranges
358

359

360
def extract_read_writes(
    fn: Callable[..., Any],
    *argsizes: Tuple[sympy.Expr, ...],
    normalize: bool = False,
    prefix: str = "d",
):
    """Symbolically trace ``fn`` over fresh index vars and collect the buffer
    reads, writes, and index expressions it performs."""
    args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
    recorder = RecordLoadStore(var_ranges, normalize=normalize)
    with V.set_ops_handler(recorder):
        fn(*args)

    # Number of vars could differ due to normalization
    range_vars = [] if normalize else [*itertools.chain.from_iterable(args)]

    op_counter = recorder.parent_handler
    inner = op_counter.parent_handler
    return ReadWrites(
        set(inner._reads),
        set(inner._writes),
        inner._index_exprs,
        range_vars,
        var_ranges,
        op_counter._op_counts,
    )
385

386

387
def extract_input_node_reduction_ranges(
    input_node: "torch._inductor.ir.TensorBox",
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
    """
    Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
    It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
    In this case, reduction_sizes of the Reduction nodes need to be the same.
    Otherwise returns (None, None).
    """

    from .ir import ComputedBuffer, Loops

    if isinstance(input_node.data, ComputedBuffer):
        # Input node has already been realized. Return its size and reduction_size.
        size = input_node.get_size()
        reduction_size = input_node.get_reduction_size()
        if len(reduction_size) > 0:
            return (size, reduction_size)
        else:
            return (None, None)

    if not isinstance(input_node.data.data, Loops):  # type: ignore[attr-defined]
        # Other IRNodes do not have reduction_ranges.
        return (None, None)

    # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
    # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
    # Is there a way to check whether there are permutations inbetween?
    reads = input_node.get_reads()
    reduction_size = None
    size = None
    # Walk the read chain until a realized reduction buffer is found (or the
    # frontier stops changing).
    while reduction_size is None and len(reads) > 0:
        seen = set()
        new_reads = []
        for read in reads:
            if not isinstance(read, MemoryDep):
                continue
            if read.name in seen:
                continue
            seen.add(read.name)
            buffer = V.graph.get_buffer(read.name)
            if buffer is None:
                continue
            if (
                isinstance(buffer, ComputedBuffer)
                and len(buffer.get_reduction_size()) > 0
            ):
                if reduction_size is None:
                    reduction_size = buffer.get_reduction_size()
                    size = buffer.get_size()
                elif (
                    reduction_size != buffer.get_reduction_size()
                    or size != buffer.get_size()
                ):
                    # Conflicting reduction shapes among inputs: give up.
                    return (None, None)
            else:
                new_reads.extend(buffer.get_reads())
        # Fixed-point check.  ``reads`` starts out as a set while
        # ``new_reads`` is a list, so a plain ``reads == new_reads`` could
        # never be true on the first pass; compare as sets so the walk
        # terminates as soon as the frontier stops changing.
        if set(reads) == set(new_reads):
            return (size, reduction_size)
        else:
            reads = new_reads
    return (size, reduction_size)
449

450

451
def canonicalization_prefix():
    """Name prefix for loop variables minted during index canonicalization."""
    return "c"
453

454

455
# ops handler which computes all the free unbacked symbols for an IR
class FreeUnbackedSymbolsOpsHandler:
    """Accumulates into ``self.symbols`` every free unbacked symbol appearing
    in the sympy expressions passed to any op while tracing."""

    symbols: Set[sympy.Symbol]

    def __init__(self):
        self.symbols = set()

    def __getattr__(self, name: str) -> Callable[..., Any]:
        # Generic op: scan all positional and keyword arguments for sympy
        # expressions and harvest their unbacked symbols; result is ignored.
        def inner(*args, **kwargs):
            for value in itertools.chain(args, kwargs.values()):
                if isinstance(value, (sympy.Expr, sympy.logic.boolalg.Boolean)):
                    self.symbols |= free_unbacked_symbols(value)

        return inner

    def indirect_indexing(self, index_var, size, check=True) -> sympy.Symbol:
        # index_var here is an opaque handle, not a sympy expression.
        assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
        self.symbols |= free_unbacked_symbols(size)
        return sympy_index_symbol(f"({str(index_var)})")

    def frexp(self, x):
        # Two outputs (mantissa, exponent); neither carries symbols here.
        return (None, None)

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[None, Tuple[None, ...]],
    ) -> Union[None, Tuple[None, ...]]:
        # Match the arity of the reduction: a tuple of placeholders for
        # multi-output reductions, a single placeholder otherwise.
        num_values = reduction_num_outputs(reduction_type)
        if num_values > 1:
            return (None,) * num_values
        return None
487

488

489
def _typecheck_FreeUnbackedSymbolsOpsHandler(
    h: FreeUnbackedSymbolsOpsHandler,
) -> OpsHandler[None]:
    # Static-typing aid only: lets the type checker verify that
    # FreeUnbackedSymbolsOpsHandler conforms to the OpsHandler[None]
    # protocol; at runtime it just returns its argument unchanged.
    return h
493

494

495
def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
    """Trace ``fn`` under a symbol-collecting ops handler and return the set
    of free unbacked symbols seen in the expressions passed to its ops."""
    from .ir import FlexibleLayout

    call_args = [index] if rindex is None else [index, rindex]
    collector = FreeUnbackedSymbolsOpsHandler()
    # NB: I cargo culted the allow_indexing patch here, I don't understand why
    # people do this all over
    with V.set_ops_handler(collector), patch.object(
        FlexibleLayout, "allow_indexing", True
    ):
        fn(*call_args)
    return collector.symbols
507

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.