#include <metal_stdlib>

using namespace metal;

constant ushort ushort_arg_0[[function_constant(0)]];
constant ushort ushort_arg_1[[function_constant(1)]];
constant ushort ushort_arg_2[[function_constant(2)]];
constant ushort ushort_arg_3[[function_constant(3)]];
constant ushort ushort_arg_4[[function_constant(4)]];
constant ushort ushort_arg_5[[function_constant(5)]];
constant ushort ushort_arg_6[[function_constant(6)]];
constant ushort ushort_arg_7[[function_constant(7)]];
constant ushort ushort_arg_8[[function_constant(8)]];
constant ushort ushort_arg_9[[function_constant(9)]];

inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }
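
// divRoundUp gives the number of RGBA texture slices needed to hold C
// channels packed four to a pixel: e.g. divRoundUp(9, 4) == 3 slices, with
// the last slice carrying one used lane and three padding lanes. The kernels
// below use it whenever a gid.z index has to be split into an
// (image, channel-slice) pair.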

kernel void affine(constant half4* scale[[buffer(0)]],
                   constant half4* shift[[buffer(1)]],
                   texture2d_array<half, access::read> in[[texture(0)]],
                   texture2d_array<half, access::write> out[[texture(1)]],
                   ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[gid.z % divRoundUp(C, 4)];
  const half4 shift_c = shift[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  const half4 x = in.read(gid_, gid.z);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid_, gid.z);
}
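
// A note on indexing (an inference from the modulo above, not original
// commentary): for a batch of N images the array holds N * divRoundUp(C, 4)
// slices, so `gid.z % divRoundUp(C, 4)` recovers the channel slice and lets
// a single set of per-channel scale/shift vectors broadcast across the batch.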

kernel void affine_nonarray(constant half4* scale[[buffer(0)]],
                            constant half4* shift[[buffer(1)]],
                            texture2d<half, access::read> in[[texture(0)]],
                            texture2d<half, access::write> out[[texture(1)]],
                            ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[0];
  const half4 shift_c = shift[0];
  half4 x = in.read(gid);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid);
}

kernel void prelu_nonshared(constant half4* weights[[buffer(0)]],
                            texture2d_array<half, access::read> in[[texture(0)]],
                            texture2d_array<half, access::write> out[[texture(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  half4 x = in.read(gid_, gid.z);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid_, gid.z);
}
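
// PReLU via select: per lane, `x > 0.0h` picks x (the identity branch) and
// otherwise x * w, i.e. y = max(0, x) + w * min(0, x). For example, with
// w = 0.25h, x = half4(2, -4, 0, -1) yields half4(2, -1, 0, -0.25).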

kernel void prelu_nonshared_nonarray(constant half4* weights[[buffer(0)]],
                                     texture2d<half, access::read> in[[texture(0)]],
                                     texture2d<half, access::write> out[[texture(1)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[0];
  half4 x = in.read(gid);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid);
}

// One block per texture.
// 256 threads per block.
using AccT = float4;

constant const bool instance_norm_has_prelu = ushort_arg_1 > 0;

kernel void instance_norm(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d_array<half, access::read> in[[texture(0)]],
    texture2d_array<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  if (gid.z >= out.get_array_size()) {
    return;
  }
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  const ushort c = gid.z % divRoundUp(C, 4);
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y), gid.z));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y), gid.z)) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[c]);
  const AccT c_bias = static_cast<AccT>(bias[c]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w = 0.0h;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[c];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled =
          static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y), gid.z)) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y), gid.z);
    }
  }
}
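
// Shape of the reduction above: each 256-thread group strides over one
// H x W slice, partial sums collapse 256 -> 32 -> 1 with barriers between
// stages, and the final pass applies y = (x - mean) * inv_std * weight + bias,
// folded into a single fused multiply-add via scale = inv_var * weight and
// shift = bias - mean * scale.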

// One block per texture.
// 256 threads per block.
kernel void instance_norm_nonarray(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d<half, access::read> in[[texture(0)]],
    texture2d<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y)));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y))) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[0]);
  const AccT c_bias = static_cast<AccT>(bias[0]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w = 0.0h;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[0];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled = static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y))) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y));
    }
  }
}

kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
                               texture2d_array<half, access::write> out[[texture(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;
  if (gid.x >= W || gid.y >= H) {
    return;
  }

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

// TODO: are the `else` branches needed?
// TODO: trick the optimizer for case where C == 4?
#define CHW_TO_CHWP4(idx, n, c_, h, w)                                     \
  if ((c_) < C) {                                                          \
    trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \
  }

  half4 trns(0, 0, 0, 0);

  CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);

#undef CHW_TO_CHWP4

  out.write(trns, gid.xy, gid.z);
}
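
// Layout example (derived from the macro above): with C = 6, divRoundUp(C, 4)
// gives two slices per image. Slice 0 of image n packs channels 0..3 at each
// (x, y); slice 1 packs channels 4 and 5 and leaves the top two lanes at the
// zero fill from `trns`, which the `(c_) < C` guard never overwrites.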

kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
                                        texture2d<half, access::write> out[[texture(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

// TODO: are the `else` branches needed?
// TODO: trick the optimizer for case where C % 4 == 0?
#define CHW_TO_CHWP4(idx, c, h, w)                        \
  if ((c) < C) {                                          \
    trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \
  }

  half4 trns(0, 0, 0, 0);

  CHW_TO_CHWP4(0, 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, 3, gid.y, gid.x);

#undef CHW_TO_CHWP4

  out.write(trns, gid.xy);
}

kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
                               device float* out[[buffer(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  half4 cs = in.read(gid.xy, gid.z);

#define CHWP4_TO_CHW(idx, n, c_, h, w)                                    \
  if ((c_) < C) {                                                         \
    out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \
  }

  CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);

#undef CHWP4_TO_CHW
}

kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        device float* out[[buffer(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

  half4 cs = in.read(gid.xy);

#define CHWP4_TO_CHW(idx, c, h, w)                        \
  if ((c) < C) {                                          \
    out[int(c) * H * W + int(h) * W + int(w)] = cs[idx];  \
  }

  CHWP4_TO_CHW(0, 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, 3, gid.y, gid.x);

#undef CHWP4_TO_CHW
}

kernel void convtranspose_upscale(texture2d_array<half, access::read> in[[texture(0)]],
                                  texture2d_array<half, access::write> out[[texture(1)]],
                                  ushort3 gid[[thread_position_in_grid]]) {
  // All resolved at compile time.
  // Assume symmetric kernel/stride/pad for now.
  const ushort kernel_ = ushort_arg_0;
  const ushort stride = ushort_arg_1;
  const ushort pad = ushort_arg_2;

  half4 zero(0.0h, 0.0h, 0.0h, 0.0h);

  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort2 gid_ = gid.xy;
  if (gid.x < kernel_ - 1 - pad || gid.y < kernel_ - 1 - pad) {
    out.write(zero, gid_, gid.z);
    return;
  }

  if (((gid.x - (kernel_ - 1 - pad)) % stride == 0) &&
      ((gid.y - (kernel_ - 1 - pad)) % stride == 0)) {
    ushort2 in_pos((gid.x - (kernel_ - 1 - pad)) / stride, (gid.y - (kernel_ - 1 - pad)) / stride);
    if (in_pos.x < in.get_width() && in_pos.y < in.get_height()) {
      half4 input = in.read(in_pos, gid.z);
      out.write(input, gid_, gid.z);
    } else {
      out.write(zero, gid_, gid.z);
    }
  } else {
    out.write(zero, gid_, gid.z);
  }
}
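
// What this computes (a restatement, not original commentary): transposed
// convolution implemented as zero-stuffing. Input pixel (ix, iy) lands at
// output (ix * stride + kernel_ - 1 - pad, iy * stride + kernel_ - 1 - pad)
// and every other output pixel is zero, so a plain convolution run on this
// upscaled texture completes the transposed convolution.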

constant bool has_in_arr = (ushort_arg_7 > 1 || ushort_arg_0 * ushort_arg_1 * ushort_arg_6 > 4);
constant bool has_out_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool has_in_tex = (!has_in_arr);
constant bool has_out_tex = (!has_out_arr);

kernel void col2im(
    texture2d_array<half, access::read> ina[[ texture(0), function_constant(has_in_arr) ]],
    texture2d<half, access::read> in[[ texture(0), function_constant(has_in_tex) ]],
    texture2d_array<half, access::write> outa[[ texture(1), function_constant(has_out_arr) ]],
    texture2d<half, access::write> out[[ texture(1), function_constant(has_out_tex) ]],
    constant half4* bias[[buffer(0)]],
    ushort3 gid[[thread_position_in_grid]]) {
  if (has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }
  const ushort kernel_h = ushort_arg_0;
  const ushort kernel_w = ushort_arg_1;
  const ushort stride_h = ushort_arg_2;
  const ushort stride_w = ushort_arg_3;
  const ushort pad_l = ushort_arg_4;
  const ushort pad_t = ushort_arg_5;
  const ushort C = ushort_arg_6;
  // const int N = ushort_arg_7;
  const ushort height_col = ushort_arg_8; // (outa.get_height() + pad + pad - kernel_) / stride + 1
  const ushort width_col = ushort_arg_9;  // (outa.get_width() + pad + pad - kernel_) / stride + 1

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  const ushort w = gid.x + pad_l;
  const ushort h = gid.y + pad_t;

  // compute the start and end of the output
  const ushort w_col_start = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
  const ushort w_col_end = min(ushort(w / stride_w + 1), ushort(width_col));
  const ushort h_col_start = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
  const ushort h_col_end = min(ushort(h / stride_h + 1), ushort(height_col));

  float4 val = static_cast<float4>(bias[c]);
  for (ushort h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (ushort w_col = w_col_start; w_col < w_col_end; ++w_col) {
      const ushort w_k = w - w_col * stride_w;
      const ushort h_k = h - h_col * stride_h;

      // layout is essentially: [N][K][K][C][H][W]
      // - where the divRoundUp(K * K * C, 4) channels are interleaved as usual.
      // Thus, it's actually [N][divRoundUp(K * K * C, 4)][H][W].

      // If C % 4 is not zero, then we have to play some games via partial indexing.
      // TODO: is it worth optimizing this loop via padding in C?
      if (C % 4 == 0) {
        ushort c_col = n * kernel_h * kernel_w * divRoundUp(C, 4) +
                       h_k * kernel_w * divRoundUp(C, 4) + w_k * divRoundUp(C, 4) + c;
        if (has_in_arr) {
          val += static_cast<float4>(ina.read(ushort2(w_col, h_col), c_col));
        } else {
          val += static_cast<float4>(in.read(ushort2(w_col, h_col), c_col));
        }
      } else {
        half4 components(0, 0, 0, 0);
        for (auto i = 0; i < 4; ++i) {
          ushort c_col_i = n * divRoundUp(kernel_h * kernel_w * C, 4) * 4 + h_k * kernel_w * C +
                           w_k * C + c * 4 + i;
          ushort c_col_i_z = c_col_i / 4;
          ushort c_col_i_off = c_col_i - c_col_i_z * 4;
          if (has_in_arr) {
            components[i] = ina.read(ushort2(w_col, h_col), c_col_i_z)[c_col_i_off];
          } else {
            components[i] = in.read(ushort2(w_col, h_col))[c_col_i_off];
          }
        }
        val += static_cast<float4>(components);
      }
    }
  }
  if (has_out_arr) {
    outa.write(static_cast<half4>(val), gid.xy, gid.z);
  } else {
    out.write(static_cast<half4>(val), gid.xy);
  }
}
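
// Index walk-through for the C % 4 == 0 path (numbers are illustrative, not
// from the original): with K = kernel_h = kernel_w = 3 and C = 4, each
// (h_k, w_k) kernel tap owns divRoundUp(C, 4) = 1 slice, so tap (1, 2) of
// image 0 reads slice 1 * 3 * 1 + 2 * 1 = 5 of the column texture at
// (w_col, h_col).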

kernel void preprocess_stylizer(device uchar4* in[[buffer(0)]],
                                constant half* mean[[buffer(1)]],
                                constant half4* noise[[buffer(2)]],
                                texture2d<half, access::write> out[[texture(0)]],
                                ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort noise_size = ushort_arg_0;

  half4 mean_half(mean[0], mean[1], mean[2], 0.0h);
  uint input_noise_idx = ((uint)out.get_width() * (uint)gid.y + (uint)gid.x) % (noise_size / 4);
  const half4 input_noise = noise[input_noise_idx];
  const uint W = out.get_width();
#define in_at(h, w) in[(uint)(h)*W + (uint)(w)]
  uchar4 input = in_at(gid.y, gid.x);
#undef in_at

  half4 input_half = static_cast<half4>(input);
  out.write(input_half - mean_half + input_noise, gid);
}

kernel void deprocess_stylizer(texture2d<half, access::read> in[[texture(0)]],
                               device uchar4* out[[buffer(0)]],
                               constant half* mean[[buffer(1)]],
                               ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= in.get_width() || gid.y >= in.get_height()) {
    return;
  }
  half4 value = in.read(gid);

  half4 mean_h(mean[0], mean[1], mean[2], 0.0h);
  half4 min_h(0.0h, 0.0h, 0.0h, 255.0h);
  half4 max_h(255.0h, 255.0h, 255.0h, 255.0h);
  half4 clamped = clamp(value + mean_h, min_h, max_h);
  const uint W = in.get_width();
#define out_at(h, w, v) out[(uint)(h)*W + (uint)(w)] = (v)
  out_at(gid.y, gid.x, static_cast<uchar4>(clamped));
#undef out_at
}

kernel void reflection_padding_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        texture2d<half, access::write> out[[texture(1)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);
  out.write(in.read(inid), gid);
}
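
// Reflection index math, worked through (an illustration of the max/min pair
// above): for H = 5 and a padded row two above the top edge, h starts at -2
// and max(h, -h) reflects it to 2. For a row two past the bottom edge, h = 6
// and min(h, 2 * 5 - 6 - 2) reflects it back to 2. In-range rows pass through
// both clamps unchanged.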

kernel void reflection_padding(texture2d_array<half, access::read> in[[texture(0)]],
                               texture2d_array<half, access::write> out[[texture(1)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);
  out.write(in.read(inid, gid.z), gid.xy, gid.z);
}

kernel void bilinear_upsample(texture2d<half, access::sample> in[[texture(0)]],
                              texture2d<half, access::write> out[[texture(1)]],
                              ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 src = gid / 2;
  constexpr sampler sampler(address::clamp_to_edge, filter::linear, coord::pixel);
  half4 value = in.sample(sampler, static_cast<float2>(src));
  out.write(value, gid);
}

constant bool in0_is_tex = ushort_arg_0 <= 1 && ushort_arg_1 <= 4;
constant bool in0_is_arr = !in0_is_tex;

kernel void elementwise_mul(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) * in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) * in1[idx % last_dim], gid_, gid.z);
  }
}
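
// Broadcast note (inferred from the modulo indexing): in1 is a flat buffer of
// length last_dim that repeats over the flattened pixel index y * W + x. When
// last_dim == W, element x of in1 multiplies every pixel in column x; when
// last_dim == 1, a single scalar scales the whole image.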

kernel void elementwise_sub(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) - in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) - in1[idx % last_dim], gid_, gid.z);
  }
}

kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  out.write(in0.read(gid) + in1.read(gid), gid);
}

kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 gid_ = gid.xy;
  out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z);
}

constant bool has_in0_arg = (ushort_arg_0 > 0);
constant bool has_in1_arg = (ushort_arg_1 > 0);
constant bool has_in2_arg = (ushort_arg_2 > 0);
constant bool has_in3_arg = (ushort_arg_3 > 0);

constant bool has_in0_tex = (has_in0_arg && ushort_arg_0 <= 4 && ushort_arg_4 <= 1);
constant bool has_in1_tex = (has_in1_arg && ushort_arg_1 <= 4 && ushort_arg_4 <= 1);
constant bool has_in2_tex = (has_in2_arg && ushort_arg_2 <= 4 && ushort_arg_4 <= 1);
constant bool has_in3_tex = (has_in3_arg && ushort_arg_3 <= 4 && ushort_arg_4 <= 1);

constant bool has_in0_array = (has_in0_arg && !has_in0_tex);
constant bool has_in1_array = (has_in1_arg && !has_in1_tex);
constant bool has_in2_array = (has_in2_arg && !has_in2_tex);
constant bool has_in3_array = (has_in3_arg && !has_in3_tex);

constant bool concat_has_out_tex = (ushort_arg_4 <= 4 && ushort_arg_5 <= 1);
constant bool concat_has_out_array = !concat_has_out_tex;

// Map an absolute output channel z to the input (0..3) that owns it, given
// per-input channel counts C0..C3.
inline ushort idx_3(ushort z, ushort C0, ushort C1, ushort C2, ushort C3) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  if (z < (C0 + C1 + C2)) {
    return 2;
  }
  return 3;
}

inline ushort idx_2(ushort z, ushort C0, ushort C1, ushort C2) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  return 2;
}

inline ushort idx_1(ushort z, ushort C0, ushort C1) {
  if (z < C0) {
    return 0;
  }
  return 1;
}

inline ushort idx_0(ushort z, ushort C0) { return 0; }

// in a texture_array with size C, find the offset for image N at plane c.
inline constexpr ushort z_off(ushort n, ushort c, ushort C) { return n * divRoundUp(C, 4) + c / 4; }
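
// z_off example (illustrative numbers): for C = 10 channels each image spans
// divRoundUp(10, 4) = 3 slices, so image n = 2, channel c = 5 lives in slice
// 2 * 3 + 5 / 4 = 7, at lane 5 % 4 = 1 within that slice's half4.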

kernel void concat(
    texture2d<half, access::read> in0[[ texture(0), function_constant(has_in0_tex) ]],
    texture2d<half, access::read> in1[[ texture(1), function_constant(has_in1_tex) ]],
    texture2d<half, access::read> in2[[ texture(2), function_constant(has_in2_tex) ]],
    texture2d<half, access::read> in3[[ texture(3), function_constant(has_in3_tex) ]],
    texture2d_array<half, access::read> ina0[[ texture(0), function_constant(has_in0_array) ]],
    texture2d_array<half, access::read> ina1[[ texture(1), function_constant(has_in1_array) ]],
    texture2d_array<half, access::read> ina2[[ texture(2), function_constant(has_in2_array) ]],
    texture2d_array<half, access::read> ina3[[ texture(3), function_constant(has_in3_array) ]],
    texture2d<half, access::write> out[[ texture(5), function_constant(concat_has_out_tex) ]],
    texture2d_array<half, access::write> outa[[ texture(5), function_constant(concat_has_out_array) ]],
    ushort3 gid[[thread_position_in_grid]]) {
  if (concat_has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }

  const ushort C0 = ushort_arg_0;
  const ushort C1 = ushort_arg_1;
  const ushort C2 = ushort_arg_2;
  const ushort C3 = ushort_arg_3;
  const ushort C = C0 + C1 + C2 + C3;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);
  // Fill channel 4*c to 4*(c+1) of nth image of output

  ushort2 gid_ = gid.xy;
  half4 value;

  for (int off = 0; off < 4; ++off) {
    ushort cur_channel = c * 4 + off;
    ushort cur_idx = 0;
    if (cur_channel >= C) {
      break;
    }
    if (has_in3_arg) {
      cur_idx = idx_3(cur_channel, C0, C1, C2, C3);
    } else if (has_in2_arg) {
      cur_idx = idx_2(cur_channel, C0, C1, C2);
    } else if (has_in1_arg) {
      cur_idx = idx_1(cur_channel, C0, C1);
    } else if (has_in0_arg) {
      cur_idx = idx_0(cur_channel, C0);
    } else {
      // never happens
      break;
    }

    ushort src_off = 0;
    switch (cur_idx) {
      case 0:
        src_off = cur_channel % 4;
        break;
      case 1:
        src_off = (cur_channel - C0) % 4;
        break;
      case 2:
        src_off = (cur_channel - (C0 + C1)) % 4;
        break;
      case 3:
        src_off = (cur_channel - (C0 + C1 + C2)) % 4;
        break;
    }
    // try to see if we can only issue one read op for the 4 values
    bool fast_path = false;
    if (off == 0 && src_off == 0 && (cur_channel + 3) < C) {
      ushort last_idx = 0;
      if (has_in3_arg) {
        last_idx = idx_3(cur_channel + 3, C0, C1, C2, C3);
      } else if (has_in2_arg) {
        last_idx = idx_2(cur_channel + 3, C0, C1, C2);
      } else if (has_in1_arg) {
        last_idx = idx_1(cur_channel + 3, C0, C1);
      } else if (has_in0_arg) {
        last_idx = idx_0(cur_channel + 3, C0);
      } else {
        // never happens
        break;
      }
      if (cur_idx == last_idx) {
        fast_path = true;
      }
    }
    switch (cur_idx) {
      case 0: {
        if (has_in0_tex) {
          if (fast_path) {
            value = in0.read(gid_);
          } else {
            value[off] = in0.read(gid_)[src_off];
          }
        }
        if (has_in0_array) {
          if (fast_path) {
            value = ina0.read(gid_, z_off(n, cur_channel, C0));
          } else {
            value[off] = ina0.read(gid_, z_off(n, cur_channel, C0))[src_off];
          }
        }
        break;
      }
      case 1: {
        if (has_in1_tex) {
          if (fast_path) {
            value = in1.read(gid_);
          } else {
            value[off] = in1.read(gid_)[src_off];
          }
        }
        if (has_in1_array) {
          if (fast_path) {
            value = ina1.read(gid_, z_off(n, cur_channel - C0, C1));
          } else {
            value[off] = ina1.read(gid_, z_off(n, cur_channel - C0, C1))[src_off];
          }
        }
        break;
      }
      case 2: {
        if (has_in2_tex) {
          if (fast_path) {
            value = in2.read(gid_);
          } else {
            value[off] = in2.read(gid_)[src_off];
          }
        }
        if (has_in2_array) {
          if (fast_path) {
            value = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2));
          } else {
            value[off] = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2))[src_off];
          }
        }
        break;
      }
      case 3: {
        if (has_in3_tex) {
          if (fast_path) {
            value = in3.read(gid_);
          } else {
            value[off] = in3.read(gid_)[src_off];
          }
        }
        if (has_in3_array) {
          if (fast_path) {
            value = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3));
          } else {
            value[off] = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3))[src_off];
          }
        }
        break;
      }
    }
    if (fast_path) {
      break;
    }
  }
  if (concat_has_out_tex) {
    out.write(value, gid_, gid.z);
  } else {
    outa.write(value, gid_, gid.z);
  }
}

using RoIT = half;
using RoIT4 = half4;

constant bool rw_has_in_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool rw_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
constant bool rw_has_in_tex = (!rw_has_in_arr);
constant bool rw_has_out_tex = (!rw_has_out_arr);

kernel void roi_warp(texture2d_array<half, access::sample> ina[[texture(0), function_constant(rw_has_in_arr)]],
                     texture2d<half, access::sample> in[[texture(0), function_constant(rw_has_in_tex)]],
                     texture2d_array<half, access::write> outa[[texture(1), function_constant(rw_has_out_arr)]],
                     texture2d<half, access::write> out[[texture(1), function_constant(rw_has_out_tex)]],
                     constant half4* rois[[buffer(0)]],
                     ushort3 gid[[thread_position_in_grid]]) {
  ushort out_width, out_height;
  if (rw_has_out_arr) {
    out_width = outa.get_width();
    out_height = outa.get_height();
  } else {
    out_width = out.get_width();
    out_height = out.get_height();
  }
  if (gid.x >= out_width || gid.y >= out_height) {
    return;
  }
  constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);

  const half spatial_scale = half(ushort_arg_0) / 10000;
  const ushort sampling_ratio = ushort_arg_1;
  const ushort C = ushort_arg_2;
  const ushort pw = gid.x;
  const ushort ph = gid.y;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z % divRoundUp(C, 4);

  const RoIT4 roi_scaled = rois[n] * spatial_scale;
  const RoIT roi_start_w = roi_scaled[0];
  const RoIT roi_start_h = roi_scaled[1];
  const RoIT roi_end_w = roi_scaled[2];
  const RoIT roi_end_h = roi_scaled[3];

  // Force malformed ROIs to be 1x1
  const RoIT roi_width = max(roi_end_w - roi_start_w, (RoIT)1.);
  const RoIT roi_height = max(roi_end_h - roi_start_h, (RoIT)1.);

  const RoIT bin_size_h = static_cast<RoIT>(roi_height) / static_cast<RoIT>(out_height);
  const RoIT bin_size_w = static_cast<RoIT>(roi_width) / static_cast<RoIT>(out_width);
  const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<RoIT>(out_height));
  const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<RoIT>(out_width));
  const ushort iy_upper = (sampling_ratio > 0) ? roi_bin_grid_h : (roi_bin_grid_h + 1);
  const ushort ix_upper = (sampling_ratio > 0) ? roi_bin_grid_w : (roi_bin_grid_w + 1);

  const RoIT count = iy_upper * ix_upper;

  RoIT4 output_val = 0.0;
  for (int iy = 0; iy < iy_upper; iy++) {
    for (int ix = 0; ix < ix_upper; ix++) {
      const RoIT y =
          roi_start_h + ph * bin_size_h + iy * bin_size_h / static_cast<RoIT>(roi_bin_grid_h);
      const RoIT x =
          roi_start_w + pw * bin_size_w + ix * bin_size_w / static_cast<RoIT>(roi_bin_grid_w);
      if (rw_has_in_arr) {
        output_val += ina.sample(s2, float2(x + 0.5, y + 0.5), c);
      } else {
        output_val += in.sample(s2, float2(x + 0.5, y + 0.5));
      }
    }
  }
  output_val /= count;
  if (rw_has_out_arr) {
    outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
  } else {
    out.write(static_cast<half4>(output_val), gid.xy);
  }
}
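
// Fixed-point convention (visible here and in nms/resize below): host-side
// floats such as spatial_scale arrive through ushort function constants
// scaled by 10000, so e.g. a spatial scale of 0.0625 is passed as 625 and
// recovered as half(625) / 10000.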

kernel void nms(device uint* mask[[buffer(0)]],
                constant float* proposals[[buffer(1)]],
                constant int* indices[[buffer(2)]],
                ushort2 tgid[[threadgroup_position_in_grid]],
                ushort2 tid[[thread_position_in_threadgroup]]) {
  const ushort num_proposals = ushort_arg_0;
  const ushort threads_per_group = ushort_arg_1;
  float nms_thresh = float(ushort_arg_2) / 10000.0;
  const ushort global_offset = ushort_arg_3;
  const ushort row_start = tgid.y;
  const ushort col_start = tgid.x;
  const ushort trd_id = tid.x;

  const short row_size = min(short(32), short(num_proposals - row_start * threads_per_group));
  const short col_size = min(short(32), short(num_proposals - col_start * threads_per_group));

  // mask the bit if the IoU between two proposals exceeds the threshold
  if (trd_id < row_size) {
    const ushort cur_idx = global_offset + row_start * threads_per_group + trd_id;
    const ushort offset = indices[cur_idx] * 4;
    const float4 cur_proposal = float4(
        proposals[offset], proposals[offset + 1], proposals[offset + 2], proposals[offset + 3]);
    uint cur_mask = 0;
    ushort group_start = 0; // start index within group
    if (row_start == col_start) {
      // if in the same group, start from the next
      group_start = trd_id + 1;
    }
    for (ushort i = group_start; i < col_size; i++) {
      float4 a = cur_proposal;
      ushort idx = indices[global_offset + col_start * threads_per_group + i] * 4;
      float4 b = float4(proposals[idx], proposals[idx + 1], proposals[idx + 2], proposals[idx + 3]);
      float left = max(a[0], b[0]);
      float right = min(a[2], b[2]);
      float top = max(a[1], b[1]);
      float bottom = min(a[3], b[3]);
      float width = max(right - left + 1.0, 0.0);
      float height = max(bottom - top + 1.0, 0.0);
      float interS = width * height;
      float Sa = (a[2] - a[0] + 1.0) * (a[3] - a[1] + 1.0);
      float Sb = (b[2] - b[0] + 1.0) * (b[3] - b[1] + 1.0);
      float iou = interS / (Sa + Sb - interS);
      if (iou - nms_thresh > 0) {
        cur_mask |= 1U << i;
      }
    }
    ushort col_blocks = (num_proposals + threads_per_group - 1) / threads_per_group;
    mask[cur_idx * col_blocks + col_start] = cur_mask;
  }
}
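
// Mask layout (a reading of the writes above): proposals are compared in
// 32 x 32 tiles, each thread owning one row of a tile. Bit i of
// mask[cur_idx * col_blocks + col_start] marks column proposal i of that tile
// as overlapping row proposal cur_idx beyond nms_thresh; these bitmasks are
// presumably walked on the host to pick the surviving proposals (an
// inference, not shown in this file).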

kernel void resize_nearest(texture2d_array<half, access::sample> in[[texture(0)]],
                           texture2d_array<half, access::write> out[[texture(1)]],
                           ushort3 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y), gid.z), gid.xy, gid.z);
}

kernel void resize_nearest_nonarray(texture2d<half, access::sample> in[[texture(0)]],
                                    texture2d<half, access::write> out[[texture(1)]],
                                    ushort2 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y)), gid.xy);
}

kernel void channel_shuffle(
    texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
    texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
    texture2d<half, access::write> out[[texture(1), function_constant(in0_is_tex)]],
    texture2d_array<half, access::write> outa[[texture(1), function_constant(in0_is_arr)]],
    ushort3 gid[[thread_position_in_grid]]) {
  ushort C = ushort_arg_1;
  ushort K = ushort_arg_2;
  ushort groups = ushort_arg_3;

  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);
  half4 value;
  ushort2 gid_ = gid.xy;
  for (int off = 0; off < 4; ++off) {
    ushort cur_channel = c * 4 + off;
    if (cur_channel >= C) {
      break;
    }
    ushort channel_id = cur_channel / groups;
    ushort group_id = cur_channel % groups;
    ushort c0 = group_id * K + channel_id;
    if (in0_is_tex) {
      value[off] = in0.read(gid_)[c0 % 4];
    } else {
      value[off] = ina0.read(gid_, c0 / 4 + n * divRoundUp(C, 4))[c0 % 4];
    }
  }
  if (in0_is_tex) {
    out.write(value, gid_);
  } else {
    outa.write(value, gid_, gid.z);
  }
}
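
// Shuffle mapping example (illustrative, and assuming the host passes
// K = C / groups as in the usual channel-shuffle formulation): for C = 6,
// groups = 2, K = 3, output channel 3 has channel_id = 1 and group_id = 1,
// so it reads input channel c0 = 1 * 3 + 1 = 4, interleaving the two groups.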