from __future__ import annotations

import functools
import logging
from typing import cast, List, Optional, Sequence, Tuple, TypedDict

import torch

from .. import config, ir
from ..ir import TensorBox
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
    lowerings as L,
    register_lowering,
)
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import (
    ceildiv,
    is_ones,
    is_zeros,
    pad_listlike,
    sympy_product,
    use_triton_template,
)
from ..virtualized import V
from .mm_common import filtered_configs

log = logging.getLogger(__name__)

aten = torch.ops.aten


def conv_grid(n, c, h, w, meta):
    return (
        ceildiv(n * h * w, meta["BLOCK_M"]),
        ceildiv(c, meta["BLOCK_N"]),
        meta["GROUPS"],
    )
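
# For example (hypothetical sizes): n=8, h=w=56, c=64 with BLOCK_M=64 and
# BLOCK_N=32 gives a launch grid of (ceildiv(8 * 56 * 56, 64), ceildiv(64, 32),
# meta["GROUPS"]) == (392, 2, GROUPS): one program instance per
# (output-pixel tile, output-channel tile, group).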


# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
kernel_configs = [
    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
    {"config": (64, 256, 16, 2, 4), "cond": True},
    {"config": (256, 64, 16, 2, 4), "cond": True},
    {"config": (1024, 16, 16, 1, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 256, 32, 2, 8), "cond": True},
    {"config": (256, 64, 32, 2, 8), "cond": True},
]

# Create filtered list of configs based on conv
platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
    platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
    )

conv_configs = functools.partial(
    filtered_configs,
    configs=platform_configs,
)
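
# Note: conv_configs is called in the autotuning path of the convolution
# lowering below with GEMM-style problem sizes (roughly m = batch * spatial
# size, n = out_chan, k = in_chan); each cfg it yields carries the BLOCK_*
# kwargs plus num_stages/num_warps that conv2d_template.maybe_append_choice
# consumes.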

LOOP_BODY = """
        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
        idx_x_c = tl.arange(0, BLOCK_K) + k

        x_ptrs = x_base + (
            (idx_x_h * stride_xh)[:, None]
            + (idx_x_w * stride_xw)[:, None]
            + (idx_x_c * stride_xc)[None, :]
        )
        mask_x = (
            (idx_n < BATCH)[:, None]
            & (idx_x_h >= 0)[:, None]
            & (idx_x_h < IN_H)[:, None]
            & (idx_x_w >= 0)[:, None]
            & (idx_x_w < IN_W)[:, None]
            & (idx_x_c < GROUP_IN_C)[None, :]
        )
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

        w_ptrs = w_base + (
            (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
        )
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
"""
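
# Each pass through LOOP_BODY is one step of an implicit GEMM: for a fixed
# kernel tap (i, j) it loads a [BLOCK_M, BLOCK_K] tile of the input (rows are
# output pixels, columns are input channels) and a [BLOCK_K, BLOCK_N] tile of
# the weight, masks out-of-range elements (padding, tensor edges) to zero, and
# accumulates their product into acc.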

"""
This is a relatively simple conv implementation that can likely be
improved.  Many alternate conv versions can be found here:
https://github.com/pytorch/torchdynamo/pull/971
"""
conv2d_template = TritonTemplate(
    name="convolution",
    grid=conv_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_H = {{size("X", 2)}}
    IN_W = {{size("X", 3)}}
    OUT_C = {{size(None, 1)}}
    OUT_H = {{size(None, 2)}}
    OUT_W = {{size(None, 3)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xh = {{stride("X", 2)}}
    stride_xw = {{stride("X", 3)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wh = {{stride("W", 2)}}
    stride_ww = {{stride("W", 3)}}

    nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = nhw % OUT_W
    nh = nhw // OUT_W
    idx_y_h = nh % OUT_H
    idx_n = nh // OUT_H
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{%- if UNROLL %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
        """
    + LOOP_BODY
    + """
{% endfor %}
{% endfor %}
{%- else %}
    # Could be simplified, but slightly slower:
    # for i in range(KERNEL_H):
    #     for j in range(KERNEL_W):
    #         for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (ijk % BLOCK_K_COUNT) * BLOCK_K
        ij = ijk // BLOCK_K_COUNT
        i = ij // KERNEL_W
        j = ij % KERNEL_W
        """
    + LOOP_BODY
    + """
{%- endif %}
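
    # At this point acc holds the complete [BLOCK_M, BLOCK_N] result tile for
    # this (output-pixel block, output-channel block, group): every kernel tap
    # and every BLOCK_K slice of the input channels has been accumulated,
    # whether via the unrolled jinja loops or the flattened ijk loop above.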

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
""",
)


aten_convolution = ExternKernelChoice(
    torch.convolution,
    "at::convolution",
    has_out_variant=False,
    op_overload=aten.convolution.default,
)


def conv1x1_via_mm(x, w, *, out):
    w = torch.squeeze(torch.squeeze(w, -1), -1)
    return torch.matmul(
        x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
    )


aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
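
# conv1x1_via_mm is offered to the autotuner as an extern alternative for
# 1x1 / stride-1 / ungrouped convolutions (see the choices built below): it
# views the NCHW tensors as channels-last and writes the matmul result
# directly into the permuted output buffer.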


class ConvLayoutParams(TypedDict):
    stride: tuple[int, ...]
    padding: tuple[int, ...]
    dilation: tuple[int, ...]
    transposed: bool
    output_padding: tuple[int, ...]
    groups: int


def conv_layout(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: Sequence[int],
    padding: tuple[int, ...],
    dilation: tuple[int, ...],
    transposed: bool,
    output_padding: tuple[int, ...],
    groups: int,
) -> ir.Layout:
    """Determine output layout for a convolution"""
    with V.graph.fake_mode:
        output = torch.ops.aten.convolution(
            ir.ir_node_to_tensor(x, guard_shape=True),
            ir.ir_node_to_tensor(weight, guard_shape=True),
            ir.ir_node_to_tensor(bias, guard_shape=True),
            V.graph.sizevars.size_hints(stride),  # type: ignore[arg-type]
            tuple(V.graph.sizevars.size_hint(p) for p in padding),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(dilation),  # type: ignore[arg-type]
            transposed,
            tuple(V.graph.sizevars.size_hint(p) for p in output_padding),  # type: ignore[arg-type]
            groups,
        )
        sizes = ir.convert_shape_to_inductor(output.size())
        stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]

    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
        sizes,
        stride,
    )
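
# Note: conv_layout runs the real aten.convolution once under fake_mode (fake
# tensors carry metadata only, no data) purely to recover the output shape and
# strides, which are then frozen into a FixedLayout for this lowering.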


def channels_last_order(rank):
    order = list(reversed(range(rank)))
    order.insert(1, order.pop(-1))
    return order
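
# Worked example: channels_last_order(4) == [3, 0, 2, 1], i.e. the NHWC stride
# order (compare ir.NHWC_STRIDE_ORDER used below) that puts an NCHW tensor
# into channels-last memory layout.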


def convert_1x1_conv_to_mm(x, weight, bias):
    # special case for 1x1 convolution, which is actually just a matmul
    rank = len(weight.get_size())
    for _ in range(rank - 2):
        weight = L[aten.squeeze](weight, dim=-1)
    weight = L[aten.permute](weight, [1, 0])

    if x.get_size()[0] != 1:
        x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
    else:
        x.realize()
        x.freeze_layout()

    x_permute = list(range(rank))
    x_permute.append(x_permute.pop(1))
    x = L[aten.permute](x, x_permute)
    *sizes, in_chan = x.get_size()
    x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
    if bias is None:
        result = L[aten.mm](x, weight)
    else:
        result = L[aten.addmm](bias, x, weight)
    result = L[aten.reshape](result, [*sizes, -1])
    result_permute = list(range(rank))
    result_permute.insert(1, result_permute.pop(-1))
    return L[aten.permute](result, result_permute)
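
# A minimal eager-mode sketch of the same equivalence (illustrative only, with
# made-up sizes; not part of the lowering): a 1x1 convolution is a matmul over
# the channel dimension.
#
#   x = torch.randn(2, 8, 4, 4)    # NCHW input
#   w = torch.randn(16, 8, 1, 1)   # 1x1 kernel, 16 output channels
#   ref = torch.nn.functional.conv2d(x, w)
#   mm = x.permute(0, 2, 3, 1).reshape(-1, 8) @ w.squeeze(-1).squeeze(-1).t()
#   mm = mm.reshape(2, 4, 4, 16).permute(0, 3, 1, 2)
#   torch.testing.assert_close(ref, mm)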


@register_lowering(aten.convolution)
def convolution(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
):
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
    if not isinstance(groups, int):
        groups = V.graph.sizevars.evaluate_static_shape(groups)
    assert isinstance(groups, int)
    kwargs: ConvLayoutParams = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "transposed": transposed,
        "output_padding": output_padding,
        "groups": groups,
    }

    if len(x.get_size()) == len(weight.get_size()) - 1:
        # add batch dimension to simplify rest of function
        return L[aten.squeeze](
            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
            dim=0,
        )

    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
        weight.get_size()
    )
    ndim = len(kernel_shape)
    stride = pad_listlike(stride, ndim)
    padding = pad_listlike(padding, ndim)
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)
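
    # pad_listlike broadcasts a single-element stride/padding/dilation to the
    # kernel's spatial rank, e.g. stride=(1,) becomes (1, 1) for a 2-D kernel,
    # so the per-dimension indexing below ([0], [1]) is always valid.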

    def channels_last_conv():
        if V.graph.layout_opt and ndim == 2:
            return True
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    if (
        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        and groups == 1
    ):
        return convert_1x1_conv_to_mm(x, weight, bias)

    if bias is not None and ir.get_device_type(x) != "cpu":
        # peel off the bias, cudnn is slower with it
        result = convolution(x, weight, None, **kwargs)
        return L[aten.add](
            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
        )

    x.realize()
    weight.realize()

    # ndim can be 1 for convolution in models such as demucs
    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
    # apply channels last.
    if V.graph.layout_opt and ndim == 2:
        V.graph.num_channels_last_conv += 1
        x = ir.ExternKernel.require_channels_last(x)
        # TODO maybe we can convert weights to channels last just once before
        # running the model.
        weight = ir.ExternKernel.require_channels_last(weight)
        layout = conv_layout(x, weight, None, **kwargs)
    else:
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        x = ir.ExternKernel.require_stride_order(x, req_stride_order)
        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)

    ordered_kwargs_for_cpp_kernel = [
        "stride", "padding", "dilation", "transposed", "output_padding", "groups",
    ]
    if bias is None:
        args = [x, weight]
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
        bias.realize()
        bias.freeze_layout()
        V.graph.sizevars.evaluate_static_shapes(bias.get_size())

    choices = [
        aten_convolution.bind(
            args,
            layout,
            ordered_kwargs_for_cpp_kernel,
            **kwargs,
        )
    ]

    if (
        use_triton_template(layout)
        # templates only support these:
        and ndim == 2
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
    ):
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))

        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
        ):
            conv2d_template.maybe_append_choice(
                choices,
                input_nodes=(x, weight),
                layout=layout,
                KERNEL_H=kernel_shape[0],
                KERNEL_W=kernel_shape[1],
                STRIDE_H=stride[0],
                STRIDE_W=stride[1],
                PADDING_H=padding[0],
                PADDING_W=padding[1],
                GROUPS=groups,
                # TODO(jansel): try unroll for bigger kernels once fixed:
                #               https://github.com/openai/triton/issues/1254
                UNROLL=is_ones(kernel_shape),
                ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                num_stages=cfg.num_stages,
                num_warps=cfg.num_warps,
                **cfg.kwargs,
            )

    return autotune_select_algorithm("convolution", choices, args, layout)


@register_lowering(aten._convolution)
def _convolution(
    x, weight, bias, stride, padding, dilation, transposed, output_padding, groups,
    benchmark, deterministic, cudnn_enabled, allow_tf32,
):
    return convolution(
        x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
    )


def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
    assert fx_node.target == torch.ops.aten.convolution.default
    if V.graph.layout_opt:
        # layout optimization decides the conv's memory format; keep args as-is
        return args, kwargs
    else:
        return constrain_to_fx_strides(fx_node, *args, **kwargs)


add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)