pytorch

Форк
0
495 строк · 14.3 Кб
1
from __future__ import annotations
2

3
import functools
4
import logging
5
from typing import cast, List, Optional, Sequence, Tuple, TypedDict
6

7
import torch
8
from .. import config, ir
9
from ..ir import TensorBox
10

11
from ..lowering import (
12
    add_layout_constraint,
13
    constrain_to_fx_strides,
14
    lowerings as L,
15
    register_lowering,
16
)
17
from ..select_algorithm import (
18
    autotune_select_algorithm,
19
    ExternKernelChoice,
20
    TritonTemplate,
21
)
22
from ..utils import (
23
    ceildiv,
24
    is_ones,
25
    is_zeros,
26
    pad_listlike,
27
    sympy_product,
28
    use_triton_template,
29
)
30
from ..virtualized import V
31
from .mm_common import filtered_configs
32

33
log = logging.getLogger(__name__)
34

35

36
aten = torch.ops.aten
37

38

39
def conv_grid(n, c, h, w, meta):
    """Compute the 3-axis launch grid used by ``conv2d_template``.

    Axis 0 tiles the flattened ``n * h * w`` extent in steps of
    ``meta["BLOCK_M"]``, axis 1 tiles ``c`` in steps of ``meta["BLOCK_N"]``,
    and axis 2 is ``meta["GROUPS"]``.  The template reads these back through
    ``tl.program_id(0)``, ``tl.program_id(1)`` and ``tl.program_id(2)``.

    NOTE(review): n/c/h/w are presumably the output batch, channel, height
    and width -- confirm against the caller that supplies the grid arguments.
    """
    tiles_nhw = ceildiv(n * h * w, meta["BLOCK_M"])
    tiles_c = ceildiv(c, meta["BLOCK_N"])
    return (tiles_nhw, tiles_c, meta["GROUPS"])
45

46

47
# Candidate tile configurations for the convolution template.  Every entry
# pairs a config tuple with a "cond" flag; only entries whose flag is truthy
# are kept for the target platform.
kernel_configs = [
    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
    {"config": (64, 256, 16, 2, 4), "cond": True},
    {"config": (256, 64, 16, 2, 4), "cond": True},
    {"config": (1024, 16, 16, 1, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 256, 32, 2, 8), "cond": True},
    {"config": (256, 64, 32, 2, 8), "cond": True},
]

# Keep just the config tuples of the enabled entries.
platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], entry["config"])
    for entry in kernel_configs
    if entry["cond"]
)

# On ROCm force num_stages to 1, as pipelining provides no benefit there.
if torch.version.hip:
    platform_configs = tuple(
        (block_m, block_n, block_k, 1, num_warps)
        for block_m, block_n, block_k, _, num_warps in platform_configs
    )
72

73
# filtered_configs with the platform's config tuples pre-bound as `configs`.
# `convolution` calls this with three positional values (a product of the
# input's non-channel sizes, out_chan, in_chan) and reads `num_stages`,
# `num_warps` and `kwargs` from each item it yields.
conv_configs = functools.partial(
    filtered_configs,
    configs=platform_configs,
)
77

78
LOOP_BODY = """
79
        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
80
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
81
        idx_x_c = tl.arange(0, BLOCK_K) + k
82

83
        x_ptrs = x_base + (
84
            (idx_x_h * stride_xh)[:, None]
85
            + (idx_x_w * stride_xw)[:, None]
86
            + (idx_x_c * stride_xc)[None, :]
87
        )
88
        mask_x = (
89
            (idx_n < BATCH)[:, None]
90
            & (idx_x_h >= 0)[:, None]
91
            & (idx_x_h < IN_H)[:, None]
92
            & (idx_x_w >= 0)[:, None]
93
            & (idx_x_w < IN_W)[:, None]
94
            & (idx_x_c < GROUP_IN_C)[None, :]
95
        )
96
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
97

98
        w_ptrs = w_base + (
99
            (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
100
        )
101
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
102
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
103
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
104
"""
105

106
"""
107
This is a relatively simple conv implementation that can likely be
108
improved.  Many alternate conv versions can be found here:
109
https://github.com/pytorch/torchdynamo/pull/971
110
"""
111
# Triton template for 2D convolution; its launch grid comes from conv_grid.
#
# Layout of the kernel text below:
#   * program_id(0) selects a BLOCK_M tile of the flattened (n, out_h, out_w)
#     index, which is split back into idx_n / idx_y_h / idx_y_w;
#     program_id(1) selects a BLOCK_N tile of output channels.
#   * When GROUPS != 1, program_id(2) selects the group and the per-group
#     channel counts are IN_C // GROUPS and OUT_C // GROUPS.
#   * With UNROLL set, the (i, j) kernel offsets are expanded at template
#     render time by the {% for %} loops, leaving only the k loop in the
#     kernel; otherwise a single fused loop over i, j and the channel blocks
#     is emitted.  Both variants splice in LOOP_BODY.
#   * The float32 accumulator is written through store_output with a mask on
#     batch, output position and output channel.
# The {{...}} / {% ... %} placeholders are expanded when the template is
# rendered; the text between them is emitted into the kernel as written.
conv2d_template = TritonTemplate(
    name="convolution",
    grid=conv_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_H = {{size("X", 2)}}
    IN_W = {{size("X", 3)}}
    OUT_C = {{size(None, 1)}}
    OUT_H = {{size(None, 2)}}
    OUT_W = {{size(None, 3)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xh = {{stride("X", 2)}}
    stride_xw = {{stride("X", 3)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wh = {{stride("W", 2)}}
    stride_ww = {{stride("W", 3)}}

    nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = nhw % OUT_W
    nh = nhw // OUT_W
    idx_y_h = nh % OUT_H
    idx_n = nh // OUT_H
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

{% if GROUPS == 1 %}
    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C
{% else %}
    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS
{% endif %}

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{% if UNROLL %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
        """
    + LOOP_BODY
    + """
{% endfor %}
{% endfor %}
{% else %}
    # Could be simplified, but slightly slower:
    # for i in range(KERNEL_H):
    #     for j in range(KERNEL_W):
    #         for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (ijk % BLOCK_K_COUNT) * BLOCK_K
        ij = ijk // BLOCK_K_COUNT
        i = ij // KERNEL_W
        j = ij % KERNEL_W
        """
    + LOOP_BODY
    + """
{% endif %}

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
""",
)
201

202
# Extern-kernel choice wrapping torch.convolution.  `convolution` always adds
# this as the first autotuning candidate.
# NOTE(review): "at::convolution" is presumably the C++ kernel name used by
# generated wrapper code -- confirm against ExternKernelChoice.
aten_convolution = ExternKernelChoice(
    torch.convolution,
    "at::convolution",
    has_out_variant=False,
    op_overload=aten.convolution.default,
)
208

209

210
def conv1x1_via_mm(x, w, *, out):
    """Evaluate a 1x1 convolution as a matmul, writing the result into ``out``.

    The last two dims of ``w`` are squeezed away and the remaining two are
    swapped.  ``x`` and ``out`` are both viewed with dim 1 moved to the end,
    so the matmul contracts over dim 1 of ``x`` and fills dim 1 of ``out``.

    Returns whatever ``torch.matmul`` returns for the permuted ``out`` view.
    """
    dim1_last = (0, 2, 3, 1)
    w_2d = w.squeeze(-1).squeeze(-1)
    return torch.matmul(
        x.permute(*dim1_last),
        w_2d.permute(1, 0),
        out=out.permute(*dim1_last),
    )
215

216

217
# Extern-kernel choice wrapping the Python helper conv1x1_via_mm.
# NOTE(review): the second argument is None, presumably meaning there is no
# C++ kernel name for this choice -- confirm against ExternKernelChoice.
aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
218

219

220
class ConvLayoutParams(TypedDict):
    """Non-tensor convolution arguments bundled as keyword arguments.

    ``convolution`` builds one of these and splats it into ``conv_layout``,
    into its own recursive calls, and into the extern kernel binding.
    """

    # The four tuple fields are built with tuple(...) from the lowering's
    # list arguments.
    stride: tuple[int, ...]
    padding: tuple[int, ...]
    dilation: tuple[int, ...]
    transposed: bool
    output_padding: tuple[int, ...]
    # Asserted to be a plain int before the dict is built.
    groups: int
227

228

229
def conv_layout(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: Sequence[int],
    padding: tuple[int, ...],
    dilation: tuple[int, ...],
    transposed: bool,
    output_padding: tuple[int, ...],
    groups: int,
) -> ir.Layout:
    """Work out the output layout of a convolution.

    Runs ``aten.convolution`` under ``V.graph.fake_mode`` on tensors built
    from the IR nodes and reads the size and stride of the result.  The
    ``padding`` and ``output_padding`` entries are each passed through
    ``V.graph.sizevars.size_hint`` first.

    Returns:
        An ``ir.FixedLayout`` with the device and dtype of ``x`` and the
        size and stride observed on the fake output.
    """
    with V.graph.fake_mode:
        fake_x, fake_weight, fake_bias = (
            ir.ir_node_to_tensor(node, guard_shape=True)
            for node in (x, weight, bias)
        )
        padding_hints = tuple(
            V.graph.sizevars.size_hint(p) for p in padding  # type: ignore[arg-type]
        )
        output_padding_hints = tuple(
            V.graph.sizevars.size_hint(p) for p in output_padding  # type: ignore[arg-type]
        )
        fake_out = torch.ops.aten.convolution(
            fake_x,
            fake_weight,
            fake_bias,
            stride,
            padding_hints,
            dilation,
            transposed,
            output_padding_hints,
            groups,
        )
        out_sizes = ir.convert_shape_to_inductor(fake_out.size())
        out_strides = ir.convert_shape_to_inductor(fake_out.stride())

    return ir.FixedLayout(x.get_device(), x.get_dtype(), out_sizes, out_strides)
262

263

264
def channels_last_order(rank):
    """Return the order list ``[rank - 1, 0, rank - 2, ..., 1]``.

    Built as the descending range with its final element (0) moved to
    index 1; for example rank 4 gives ``[3, 0, 2, 1]``.
    ``convert_1x1_conv_to_mm`` passes the result to
    ``ir.ExternKernel.require_stride_order``.
    """
    descending = [*range(rank - 1, -1, -1)]
    lowest = descending.pop()
    descending.insert(1, lowest)
    return descending
268

269

270
def convert_1x1_conv_to_mm(x, weight, bias):
    """Lower a 1x1 convolution as a matrix multiply.

    The weight has its trailing dims squeezed away and is transposed; ``x``
    has dim 1 moved to the end and is flattened to 2-D.  The product is
    computed with ``mm`` (or ``addmm`` when ``bias`` is given), reshaped
    back, and permuted so the last dim returns to position 1.
    """
    rank = len(weight.get_size())

    # Squeeze once per dim beyond the first two, then swap the two left.
    for _ in range(2, rank):
        weight = L[aten.squeeze](weight, dim=-1)
    weight = L[aten.permute](weight, [1, 0])

    if x.get_size()[0] != 1:
        x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
    else:
        x.realize()
        x.freeze_layout()

    # Move dim 1 of x to the end, then collapse the leading dims into one.
    to_last = [*range(rank)]
    to_last.append(to_last.pop(1))
    x = L[aten.permute](x, to_last)
    *lead_sizes, in_chan = x.get_size()
    x = L[aten.reshape](x, [sympy_product(lead_sizes), in_chan])

    if bias is None:
        product = L[aten.mm](x, weight)
    else:
        product = L[aten.addmm](bias, x, weight)

    # Restore the leading dims and move the last dim back to position 1.
    product = L[aten.reshape](product, [*lead_sizes, -1])
    to_second = [*range(rank)]
    to_second.insert(1, to_second.pop(-1))
    return L[aten.permute](product, to_second)
296

297

298
@register_lowering(aten.convolution)
def convolution(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Lower ``aten.convolution`` by gathering candidate kernels and autotuning.

    Steps, in order:

    1. Convert the list arguments to tuples and make ``groups`` a plain int.
    2. If ``x`` has one dim fewer than ``weight``, expand a leading dim of
       size 1, recurse, and squeeze that dim off the result.
    3. If the convolution is a plain 1x1 (unit kernel/stride/dilation, zero
       padding/output padding, not transposed, one group) and either
       ``config.conv_1x1_as_mm`` is set or gemm autotuning is on with a
       channels-last layout, lower it via ``convert_1x1_conv_to_mm``.
    4. On non-CPU devices, peel the bias off: recurse without it and add the
       bias (viewed as ``[C, 1, ..., 1]``) afterwards.
    5. Realize the inputs and require a stride order for them.
    6. Collect choices: the extern ``aten_convolution`` always; when the
       Triton template is usable, also ``aten_conv1x1_via_mm`` (1x1 case)
       and one ``conv2d_template`` choice per config.
    7. Return the result of ``autotune_select_algorithm``.

    Args:
        x: Input node.
        weight: Weight node; its size is evaluated statically and unpacked
            as ``(out_chan, in_chan, *kernel_shape)``.
        bias: Optional bias node; ``None`` is accepted and is what the
            recursive bias-peeling call passes.
        stride, padding, dilation, output_padding: Per-dim parameters.
        transposed: Whether this is a transposed convolution.
        groups: Group count; a non-int value is evaluated statically.
    """
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
    if not isinstance(groups, int):
        groups = V.graph.sizevars.evaluate_static_shape(groups)
    assert isinstance(groups, int)
    # Captured before the pad_listlike calls below, so kwargs holds the
    # parameters as they were passed in.
    kwargs: ConvLayoutParams = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "transposed": transposed,
        "output_padding": output_padding,
        "groups": groups,
    }

    if len(x.get_size()) == len(weight.get_size()) - 1:
        # add batch dimension to simplify rest of function
        return L[aten.squeeze](
            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
            dim=0,
        )

    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
        weight.get_size()
    )
    ndim = len(kernel_shape)
    # NOTE(review): pad_listlike presumably expands a single value to ndim
    # entries -- confirm against its definition in ..utils.
    stride = pad_listlike(stride, ndim)
    padding = pad_listlike(padding, ndim)
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)

    def channels_last_conv():
        # True when layout optimisation applies to this 2D conv, or when the
        # layout computed for the conv has the NHWC stride order.
        if V.graph.layout_opt and ndim == 2:
            return True

        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    if (
        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        and groups == 1
    ):
        return convert_1x1_conv_to_mm(x, weight, bias)

    if bias is not None and ir.get_device_type(x) != "cpu":
        # peel off the bias, cudnn is slower with it
        result = convolution(x, weight, None, **kwargs)
        return L[aten.add](
            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
        )

    x.realize()
    weight.realize()

    # ndim can be 1 for convolution in models such as demucs
    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
    # apply channels last.
    if V.graph.layout_opt and ndim == 2:
        V.graph.num_channels_last_conv += 1
        x = ir.ExternKernel.require_channels_last(x)
        # TODO maybe we can convert weights to channels last just once before
        # running the model.
        weight = ir.ExternKernel.require_channels_last(weight)
        layout = conv_layout(x, weight, None, **kwargs)
    else:
        # Otherwise make both inputs follow the stride order of the layout
        # computed for the output.
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        x = ir.ExternKernel.require_stride_order(x, req_stride_order)
        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)

    ordered_kwargs_for_cpp_kernel = [
        "stride",
        "padding",
        "dilation",
        "transposed",
        "output_padding",
        "groups",
    ]
    if bias is None:
        # Without a bias node, bias travels as the first keyword argument
        # (value None) instead of as a positional input.
        args = [x, weight]
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
        bias.realize()
        bias.freeze_layout()
        # Result is discarded; called for its effect on the size variables.
        V.graph.sizevars.evaluate_static_shapes(bias.get_size())
    choices = [
        aten_convolution.bind(
            args,
            layout,
            ordered_kwargs_for_cpp_kernel,
            **kwargs,
        )
    ]

    if (
        use_triton_template(layout)
        # templates only support these:
        and ndim == 2
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
    ):
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))

        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
        ):
            conv2d_template.maybe_append_choice(
                choices,
                input_nodes=(x, weight),
                layout=layout,
                KERNEL_H=kernel_shape[0],
                KERNEL_W=kernel_shape[1],
                STRIDE_H=stride[0],
                STRIDE_W=stride[1],
                PADDING_H=padding[0],
                PADDING_W=padding[1],
                GROUPS=groups,
                # TODO(jansel): try unroll for bigger kernels once fixed:
                #               https://github.com/openai/triton/issues/1254
                UNROLL=is_ones(kernel_shape),
                ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                num_stages=cfg.num_stages,
                num_warps=cfg.num_warps,
                **cfg.kwargs,
            )

    return autotune_select_algorithm("convolution", choices, args, layout)
464

465

466
@register_lowering(aten._convolution)
def _convolution(
    x,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32,
):
    """Lower ``aten._convolution`` by delegating to ``convolution``.

    Only the first nine arguments are forwarded, positionally and in order.
    ``benchmark``, ``deterministic``, ``cudnn_enabled`` and ``allow_tf32``
    are accepted but not used.
    """
    forwarded = (
        x,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )
    return convolution(*forwarded)
485

486

487
def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
    """Layout constraint for ``aten.convolution``.

    Asserts that ``fx_node`` targets ``aten.convolution.default``.  When
    ``V.graph.layout_opt`` is set the arguments are returned unchanged as
    ``(args, kwargs)``; otherwise the call is delegated to
    ``constrain_to_fx_strides``.
    """
    assert fx_node.target == torch.ops.aten.convolution.default
    if not V.graph.layout_opt:
        return constrain_to_fx_strides(fx_node, *args, **kwargs)
    return args, kwargs


# Register the constraint above for aten.convolution.
add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
496

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.