pytorch

Форк
0
/
ops_handler.py 
655 строк · 18.5 Кб
1
import itertools
2
from typing import Any, Callable, Generic, Literal, Optional, Tuple, TypeVar, Union
3
from unittest.mock import patch
4

5
import sympy
6
from typing_extensions import Protocol
7

8
import torch
9
import torch.utils._pytree as pytree
10
from torch.fx.graph import inplace_methods, magic_methods
11
from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str
12

13
# Value domain of an ops handler: the type every op takes and returns
# wherever compute occurs.
T = TypeVar("T")

# Optional store mode: None for a plain store, "atomic_add" to make the store
# an atomic addition.
StoreMode = Union[Literal["atomic_add"], None]

# Reduction kinds accepted by OpsHandler.reduction.  The two Welford kinds
# return multiple outputs (reduction_num_outputs reports how many).
ReductionType = Literal[
    "argmax",
    "argmin",
    "welford_reduce",
    "welford_combine",
    "any",
    "max",
    "min",
    "prod",
    "sum",
    "xor_sum",
]
27

28

29
def _arg_str(a) -> str:
    """Render one op argument as text.

    sympy expressions are printed with ``sympy_str``; anything else falls
    back to plain ``str``.
    """
    return sympy_str(a) if isinstance(a, sympy.Expr) else str(a)
33

34

35
# NB: This is not done as a parent class, because our ops handler
# implementations make heavy use of __getattr__ magic, and pre-existing
# stubs for methods would interfere with this mechanism.
#
# TODO: A superclass that does desugaring for operations like
# reciprocal/square might be useful.
class OpsHandler(Protocol[T]):
    """
    Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
    as well as the contract for op handlers.  The type T signifies the domain
    of the abstract analysis AKA what all of the functions return / take as arguments
    anywhere compute occurs.

    While these operators are typically dtype polymorphic (e.g., you can use mul
    on both integers and floats), they do NOT do promotion and usually return the
    same dtype as the input.  You are expected to have handled type promotion
    during ATen decompositions.  Most operators correspond exactly to pointwise
    operations as defined by torch, so when in doubt about semantics, check the
    corresponding torch documentation.  These are all scalar operations (so they
    are defined to operate on a single element at a time.)

    For convenience, many operators take a src_dtype which indicates what the dtype
    of the input argument is.  Although in principle this can be derived by an
    analysis, providing this for ops where it is useful helps avoid having to repeatedly
    recompute dtype in code generation.

    Note that this often describes a class of static methods, for stateless
    ops handlers.

    Handlers are often defined using ``__getattr__`` metaprogramming, which means
    that you cannot declare that a type implements a protocol by inheriting from
    it (as the type stubs count as attribute declarations and impede the getattr
    magic method from being called).  Instead, define a function that casts an
    argument of your type to the protocol, which is sufficient to induce mypy to
    test that the protocol is implemented correctly.  Search for ``_typecheck_``
    in this file to see some examples.  If you see an obscure error where a
    class doesn't implement a Protocol, but mypy doesn't say why, check to see
    that ``__getattr__`` is typed correctly (typically, it is not possible to
    type ``__getattr__`` without typing it as ``Callable[..., Any]``)
    """

    def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
        """Produces a scalar constant of type dtype."""
        ...

    def load_seed(self, name: str, offset: T) -> T:
        """Computes inductor_prims.lookup_seed."""
        ...

    def rand(self, seed: T, offset: T) -> T:
        """Computes inductor_prims.random with mode="rand".  offset has dtype int32."""
        ...

    def randn(self, seed: T, offset: T) -> T:
        """Computes inductor_prims.random with mode="randn".  offset has dtype int32."""
        ...

    def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
        """Computes inductor_prims.randint.  offset has dtype int32."""
        ...

    def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
        """
        Computes body, but only perform loads/stores if the boolean mask
        evaluates to true.  For example, you would use this if you needed to
        perform an indirect load that may not be valid on some elements;
        without masking, invalid accesses can cause IMAs.  When mask is true,
        the result is the result of body; otherwise it is other.

        Contrast this with ops.where, which can multiplex between two values
        that have been unconditionally computed.
        """
        ...

    def where(self, condition: T, input: T, other: T) -> T:
        """
        Computes torch.where: when condition is true, return input; otherwise return other.
        """
        ...

    def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
        """
        Converts a sympy expression into a scalar of type dtype.  expr is typically
        an indexing expression, thus the name; however, it can also be used in
        non-indexing situations.
        """
        ...

    def to_dtype(
        self, x: T, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
    ) -> T:
        """
        Convert x to dtype.  src_dtype can be optionally set to specify what the original
        dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
        """
        ...

    def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
        """
        Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
        src_dtype must be the original type of x.
        """
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # These operations are only available in a "kernel" context.  Check
    # torch._inductor.codegen.common.CSEProxy for their typical implementation
    # in op handler (routing to their respective implementations in the kernel
    # handler)
    #
    # Importantly, inside a kernel, indexing and mask variables are available
    # in scope, which are typically used by sympy.Expr indexing.

    def indirect_indexing(
        self, x: T, size: sympy.Expr, check: bool = True
    ) -> sympy.Expr:
        """
        Convert an integral x into a sympy.Expr that can be subsequently used in
        indexing computation.  'size' represents an upper bound on what valid
        indexes can be; when 'check' is True, we check that x is in bounds.

        NB: This is typically mandatory to implement for any analysis, because you
        MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
        """
        ...

    def load(self, name: str, index: sympy.Expr) -> T:
        """
        Load from the memory location 'name', offset by some indexing expression 'index'.
        """
        ...

    def store(
        self,
        name: str,
        index: sympy.Expr,
        value: T,
        mode: StoreMode = None,
    ) -> None:
        """
        Store 'value' to the memory location 'name' offset by 'index'.  If
        specified, 'mode' can require the store to be an atomic addition.
        """
        ...

    # TODO: Better explain the "collective" semantics of these ops;
    # remember that the input value is a scalar, you can't reduce on it in the
    # traditional sense!
    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: T,
    ) -> Union[T, Tuple[T, ...]]:
        """
        Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype',
        using 'dtype' as the accumulation dtype for the reduction.  The result
        is an intermediate computation which should be stored to the final
        location using 'ops.store_reduction'.

        Valid reduction types are the members of ``ReductionType``: argmax,
        argmin, welford_reduce, welford_combine, any, max, min, prod, sum and
        xor_sum.  For Welford reduction types, this function returns multiple
        outputs; consult reduction_num_outputs to determine the amount in
        metaprogramming applications.
        """
        ...

    # TODO: in practice, this seems to actually return None, but not returning
    # a T makes common __getattr__ idioms not type correctly.  Figure out if
    # this should be returning something.
    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
        """
        Store the fully accumulated result of 'reduction' to the memory
        location 'name' offset by 'index'.
        """
        ...

    def scan(
        self, dtype: torch.dtype, combine_fn: Callable[[T, T], T], value: T, init: int
    ) -> T:
        """
        Perform an associative scan on 'value'.
        """
        # TODO: Improve the description with some pseudocode
        ...

    def bucketize(
        self,
        values: T,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> T:
        # See [Note: Inductor bucketize op]
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The following ops have semantics that correspond exactly to the torch
    # operation with the same corresponding name.

    def abs(self, x0: T) -> T:
        ...

    def exp(self, x0: T) -> T:
        ...

    def exp2(self, x0: T) -> T:
        ...

    def expm1(self, x0: T) -> T:
        ...

    def sqrt(self, x0: T) -> T:
        ...

    def relu(self, x0: T) -> T:
        ...

    def minimum(self, x0: T, x1: T) -> T:
        ...

    def maximum(self, x0: T, x1: T) -> T:
        ...

    def cos(self, x0: T) -> T:
        ...

    def sin(self, x0: T) -> T:
        ...

    def lgamma(self, x0: T) -> T:
        ...

    def erf(self, x0: T) -> T:
        ...

    def cosh(self, x0: T) -> T:
        ...

    def sinh(self, x0: T) -> T:
        ...

    def acos(self, x0: T) -> T:
        ...

    def acosh(self, x0: T) -> T:
        ...

    def asin(self, x0: T) -> T:
        ...

    def asinh(self, x0: T) -> T:
        ...

    def atan2(self, x0: T, x1: T) -> T:
        ...

    def atan(self, x0: T) -> T:
        ...

    def atanh(self, x0: T) -> T:
        ...

    def copysign(self, x0: T, x1: T) -> T:
        ...

    def erfc(self, x0: T) -> T:
        ...

    def erfinv(self, x0: T) -> T:
        ...

    def frexp(self, x0: T) -> Tuple[T, T]:
        """Computes torch.frexp, producing the (mantissa, exponent) pair."""
        ...

    def hypot(self, x0: T, x1: T) -> T:
        ...

    def log10(self, x0: T) -> T:
        ...

    def nextafter(self, x0: T, x1: T) -> T:
        ...

    def logical_and(self, x0: T, x1: T) -> T:
        ...

    def logical_not(self, x0: T) -> T:
        ...

    def logical_or(self, x0: T, x1: T) -> T:
        ...

    def logical_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_and(self, x0: T, x1: T) -> T:
        ...

    def bitwise_not(self, x0: T) -> T:
        ...

    def bitwise_or(self, x0: T, x1: T) -> T:
        ...

    def bitwise_xor(self, x0: T, x1: T) -> T:
        ...

    def bitwise_left_shift(self, x0: T, x1: T) -> T:
        ...

    def bitwise_right_shift(self, x0: T, x1: T) -> T:
        ...

    def rsqrt(self, x0: T) -> T:
        ...

    def log1p(self, x0: T) -> T:
        ...

    def tan(self, x0: T) -> T:
        ...

    def tanh(self, x0: T) -> T:
        ...

    def sigmoid(self, x0: T) -> T:
        ...

    def signbit(self, x0: T) -> T:
        ...

    def fmod(self, x0: T, x1: T) -> T:
        ...

    def log(self, x0: T) -> T:
        ...

    def isinf(self, x0: T) -> T:
        ...

    def isnan(self, x0: T) -> T:
        ...

    def round(self, x0: T) -> T:
        ...

    def floor(self, x0: T) -> T:
        ...

    def sign(self, x0: T) -> T:
        ...

    def to_int(self, x0: T) -> T:
        ...

    def trunc(self, x0: T) -> T:
        ...

    def truncdiv(self, x0: T, x1: T) -> T:
        ...

    def ceil(self, x0: T) -> T:
        ...

    def neg(self, x0: T) -> T:
        ...

    def reciprocal(self, x0: T) -> T:
        ...

    def eq(self, x0: T, x1: T) -> T:
        ...

    def ne(self, x0: T, x1: T) -> T:
        ...

    def lt(self, x0: T, x1: T) -> T:
        ...

    def gt(self, x0: T, x1: T) -> T:
        ...

    def le(self, x0: T, x1: T) -> T:
        ...

    def ge(self, x0: T, x1: T) -> T:
        ...

    def add(self, x0: T, x1: T) -> T:
        ...

    def sub(self, x0: T, x1: T) -> T:
        ...

    def mul(self, x0: T, x1: T) -> T:
        ...

    def floordiv(self, x0: T, x1: T) -> T:
        ...

    def truediv(self, x0: T, x1: T) -> T:
        ...

    def div(self, x0: T, x1: T) -> T:
        ...

    def mod(self, x0: T, x1: T) -> T:
        ...

    def pow(self, x0: T, x1: T) -> T:
        ...

    def and_(self, x0: T, x1: T) -> T:
        ...

    def or_(self, x0: T, x1: T) -> T:
        ...

    def xor(self, x0: T, x1: T) -> T:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # In CUDA, optimized implementations of other mathematical operations are
    # offered separately via libdevice for double precision computation (in
    # Triton, these go to tl.math rather than tl).  We lower to these
    # operators when doing FP64 on CUDA.  Note that some operators
    # unconditionally go to tl.math.
    #
    # TODO(ezyang): Is this really the best way to do this?  What if we have
    # abs internally route to tl.math automatically when given a double
    # precision input?  One reason is that when doing codegen, we often don't
    # know what the dtype of the inputs are!  (In principle we do know, but
    # for many analyses it's not conveniently available.)

    def libdevice_abs(self, x0: T) -> T:
        ...

    def libdevice_exp(self, x0: T) -> T:
        ...

    def libdevice_sqrt(self, x0: T) -> T:
        ...

    def libdevice_cos(self, x0: T) -> T:
        ...

    def libdevice_sin(self, x0: T) -> T:
        ...

    def libdevice_sigmoid(self, x0: T) -> T:
        ...

    def libdevice_log(self, x0: T) -> T:
        ...
492

493
class MockHandler:
    """
    Ops handler that renders every op call as a string instead of computing it.

    Any op not defined explicitly is produced by ``__getattr__`` and formats as
    ``ops.<name>(<args>)``: positional arguments go through ``_arg_str`` (so
    sympy expressions are printed with ``sympy_str``), keyword arguments as
    ``key=value``.  ``_init_cls`` additionally installs the torch.fx magic and
    in-place methods, each formatted with its torch.fx format string.
    """

    def __getattr__(self, name):
        # The attribute ``name`` is answered with this handler's identifier
        # rather than being turned into an op renderer.
        if name == "name":
            return "MockHandler"

        def render(*args, **kwargs):
            pieces = [_arg_str(arg) for arg in args] + [
                f"{key}={val}" for key, val in kwargs.items()
            ]
            return f"ops.{name}({', '.join(pieces)})"

        return render

    @staticmethod
    def masked(mask, body, other) -> str:
        # body is invoked immediately so its rendering is inlined in the result.
        return f"ops.masked({mask}, {body()}, {other})"

    @staticmethod
    def frexp(x):
        # Two outputs: render each one as an index into the same call string.
        call = f"ops.frexp({x})"
        return (call + "[0]", call + "[1]")

    @staticmethod
    def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
        # The resulting index symbol is named after the parenthesized text of
        # index_var; size and check are ignored by this handler.
        return sympy_index_symbol("(" + str(index_var) + ")")

    @classmethod
    def _init_cls(cls):
        """Attach a static renderer for every torch.fx magic / in-place method.

        Each renderer fills the method's format string with its positional
        arguments.  Entries are installed magic_methods first, then
        inplace_methods, so a name present in both ends up with the latter.
        """

        def make_handler(format_string):
            def render(*args):
                return format_string.format(*args)

            return staticmethod(render)

        for table in (magic_methods, inplace_methods):
            for name, format_string in table.items():
                setattr(cls, name, make_handler(format_string))


MockHandler._init_cls()
533

534

535
# Use mypy to check that MockHandler implements the OpsHandler protocol.
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
    """Exists for mypy only: the annotated assignment forces the protocol check."""
    checked: OpsHandler[str] = h
    return checked
538

539

540
class KernelFormatterHandler:
    """
    Ops handler that records a kernel's ops as straight-line source text.

    Every op is delegated to ``parent_handler``; each string it returns (or
    each leaf of a returned pytree) is bound to a fresh ``tmpN`` name, and the
    assignment is written to ``self.output``.  ``getvalue`` finishes the text.
    """

    def __init__(self, parent_handler):
        self.parent_handler = parent_handler
        # Body lines are written one indent level deep; ir_to_string dedents
        # once to emit the ``def`` header above them.
        self.output = IndentedBuffer(1)
        self.var_counter = itertools.count()

    @staticmethod
    def ir_to_string(ir_fn, index, rindex=None) -> str:
        """
        Render ``ir_fn`` as the source of a function ``inner_fn``.

        ``ir_fn`` is called with ``index`` (and ``rindex`` when given) while
        this handler, backed by a MockHandler, is the active ops handler and
        ``FlexibleLayout.allow_indexing`` is patched to True.  Returns the
        generated text, ending in ``return <result>``.
        """
        from .ir import FlexibleLayout
        from .virtualized import V

        if rindex is None:
            args = [index]
            names = ["index"]
        else:
            args = [index, rindex]
            names = ["index", "rindex"]

        formatter = KernelFormatterHandler(MockHandler())
        out = formatter.output

        with out.indent(-1):
            out.writeline(f"def inner_fn({', '.join(names)}):")

        # Unpack each non-empty index argument; integer entries become ``_``.
        for name, arg in zip(names, args):
            if not arg:
                continue
            targets = ", ".join(
                "_" if isinstance(v, (int, sympy.Integer)) else str(v) for v in arg
            )
            out.writeline(f"{targets} = {name}")

        with V.set_ops_handler(formatter):
            with patch.object(FlexibleLayout, "allow_indexing", True):
                result = ir_fn(*args)
                return formatter.getvalue(result)

    def __getattr__(self, name) -> Callable[..., Any]:
        def inner(*args, **kwargs):
            produced = getattr(self.parent_handler, name)(*args, **kwargs)
            # indirect_indexing yields a sympy index expression, not a value,
            # so it is passed through without binding a temporary.
            if name == "indirect_indexing":
                return produced

            def bind(expr):
                tmp = f"tmp{next(self.var_counter)}"
                self.output.writeline(f"{tmp} = {expr}")
                return tmp

            return pytree.tree_map(bind, produced)

        return inner

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[str, Tuple[str, ...]],
    ) -> Union[str, Tuple[str, ...]]:
        """Bind the reduction's output(s) to fresh ``tmpN`` names on one line."""
        rendered = self.parent_handler.reduction(
            dtype, src_dtype, reduction_type, value
        )
        count = reduction_num_outputs(reduction_type)
        names = [f"tmp{next(self.var_counter)}" for _ in range(count)]
        self.output.writeline(f"{','.join(names)} = {rendered}")
        if count > 1:
            return tuple(names)
        return names[0]

    def getvalue(self, result):
        """Append ``return <result>`` and hand back the accumulated text."""
        out = self.output
        out.writeline(f"return {result}")
        return out.getvalue()
605

606

607
# Use mypy to check that KernelFormatterHandler implements the OpsHandler protocol.
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
    """Exists for mypy only: the annotated assignment forces the protocol check."""
    checked: OpsHandler[str] = h
    return checked
610

611

612
class WrapperHandler(Generic[T]):
    """
    Ops handler that forwards every op to a wrapped ``inner`` handler.

    Methods defined on a subclass take precedence over ``__getattr__``, so a
    subclass can override individual ops and delegate the rest.
    """

    def __init__(self, inner: OpsHandler[T]):
        self._inner = inner

    def __getattr__(self, item):
        # __getattr__ runs only for attributes not found normally.  If
        # ``_inner`` itself is missing (an instance created without __init__,
        # as copy/pickle reconstruction does), reading ``self._inner`` here
        # would re-enter __getattr__ until RecursionError.  Report it as a
        # plain missing attribute instead so hasattr()/getattr(..., default)
        # probes behave.
        if item == "_inner":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._inner, item)
618

619

620
# Use mypy to check that WrapperHandler[T] implements OpsHandler[T].
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
    """Exists for mypy only: the annotated assignment forces the protocol check."""
    checked: OpsHandler[T] = h
    return checked
623

624

625
class OpCounterCSE:
    """
    Shim to count how many ops are used.

    Every op is forwarded to ``parent_handler``.  Each distinct value it
    returns (each leaf, for pytree results) is mapped to a ``tmpN`` name with
    ``N`` counting up from 0; a value seen before reuses its existing name.
    ``op_count`` therefore equals the number of distinct values seen.
    """

    def __init__(self, inner):
        super().__init__()
        self.parent_handler = inner
        self.op_count = 0
        self.var_names = {}

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            result = getattr(self.parent_handler, name)(*args, **kwargs)
            # indirect_indexing returns an index expression, not an op value.
            if name == "indirect_indexing":
                return result

            def dedup(expr):
                if expr in self.var_names:
                    return self.var_names[expr]
                fresh = f"tmp{self.op_count}"
                self.op_count += 1
                self.var_names[expr] = fresh
                return fresh

            return pytree.tree_map(dedup, result)

        return inner
652

653

654
def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
    """Exists for mypy only: checks OpCounterCSE implements OpsHandler[str]."""
    checked: OpsHandler[str] = h
    return checked
656

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.