#include <metal_stdlib>

using namespace metal;

constant ushort ushort_arg_0[[function_constant(0)]];
constant ushort ushort_arg_1[[function_constant(1)]];
constant ushort ushort_arg_2[[function_constant(2)]];
constant ushort ushort_arg_3[[function_constant(3)]];
constant ushort ushort_arg_4[[function_constant(4)]];
constant ushort ushort_arg_5[[function_constant(5)]];
constant ushort ushort_arg_6[[function_constant(6)]];
constant ushort ushort_arg_7[[function_constant(7)]];
constant ushort ushort_arg_8[[function_constant(8)]];
constant ushort ushort_arg_9[[function_constant(9)]];
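
// The ushort_arg_* values above are Metal function constants: the host binds
// them when each compute pipeline is specialized, so the kernels below read
// their shape/option parameters (C, H, W, strides, pads, flags, ...) at
// pipeline-build time rather than from a buffer. Which slot means what is
// kernel-specific. A minimal host-side sketch (Objective-C, illustrative
// only and not part of this file; `lib` and `err` are placeholders):
//
//   MTLFunctionConstantValues* cv = [MTLFunctionConstantValues new];
//   ushort C = 8;
//   [cv setConstantValue:&C type:MTLDataTypeUShort atIndex:0]; // ushort_arg_0
//   id<MTLFunction> fn =
//       [lib newFunctionWithName:@"affine" constantValues:cv error:&err];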

inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }
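
// divRoundUp(C, 4) is the number of RGBA slices needed to hold C channels,
// e.g. divRoundUp(10, 4) == (10 + 3) / 4 == 3. Tensors live in
// texture2d_array<half> with 4 channels packed per slice: slice gid.z holds
// channels 4*c .. 4*c+3 of image n, where n = gid.z / divRoundUp(C, 4) and
// c = gid.z % divRoundUp(C, 4). The "_nonarray" kernel variants below handle
// the single-image, C <= 4 case with a plain texture2d.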

kernel void affine(constant half4* scale[[buffer(0)]],
                   constant half4* shift[[buffer(1)]],
                   texture2d_array<half, access::read> in[[texture(0)]],
                   texture2d_array<half, access::write> out[[texture(1)]],
                   ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[gid.z % divRoundUp(C, 4)];
  const half4 shift_c = shift[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  const half4 x = in.read(gid_, gid.z);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid_, gid.z);
}

kernel void affine_nonarray(constant half4* scale[[buffer(0)]],
                            constant half4* shift[[buffer(1)]],
                            texture2d<half, access::read> in[[texture(0)]],
                            texture2d<half, access::write> out[[texture(1)]],
                            ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[0];
  const half4 shift_c = shift[0];
  half4 x = in.read(gid);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid);
}
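
// PReLU: y = x for x > 0 and y = w * x otherwise. Metal's select(a, b, c)
// returns b where c is true, so select(x * w, x, x > 0.0h) applies the slope
// only to negative lanes. With S == 1 a single shared slope is broadcast to
// all four channel lanes; otherwise each packed channel gets its own weight.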
kernel void prelu_nonshared(constant half4* weights[[buffer(0)]],
                            texture2d_array<half, access::read> in[[texture(0)]],
                            texture2d_array<half, access::write> out[[texture(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  half4 x = in.read(gid_, gid.z);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid_, gid.z);
}

kernel void prelu_nonshared_nonarray(constant half4* weights[[buffer(0)]],
                                     texture2d<half, access::read> in[[texture(0)]],
                                     texture2d<half, access::write> out[[texture(1)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[0];
  half4 x = in.read(gid);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid);
}

// One block per texture.
// 256 threads per block.
using AccT = float4;

constant const bool instance_norm_has_prelu = ushort_arg_1 > 0;
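
// Instance norm: one 256-thread threadgroup per texture slice, three passes.
// Pass 1 grid-strides over the plane accumulating per-thread sums, then
// tree-reduces 256 -> 32 -> 1 to obtain the mean; pass 2 repeats this over
// squared deviations to obtain the variance; pass 3 normalizes and applies
// weight/bias, optionally fusing a PReLU (compiled in when ushort_arg_1 > 0).
// Sums are accumulated in float4 (AccT), presumably to avoid half-precision
// overflow on large planes.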
kernel void instance_norm(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d_array<half, access::read> in[[texture(0)]],
    texture2d_array<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  if (gid.z >= out.get_array_size()) {
    return;
  }
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  const ushort c = gid.z % divRoundUp(C, 4);
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y), gid.z));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y), gid.z)) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[c]);
  const AccT c_bias = static_cast<AccT>(bias[c]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[c];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled =
          static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y), gid.z)) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y), gid.z);
    }
  }
}

// One block per texture.
// 256 threads per block.
kernel void instance_norm_nonarray(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d<half, access::read> in[[texture(0)]],
    texture2d<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y)));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y))) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[0]);
  const AccT c_bias = static_cast<AccT>(bias[0]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[0];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled = static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y))) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y));
    }
  }
}
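
// NCHW <-> texture copies: a float NCHW buffer is repacked so that channels
// c, c+1, c+2, c+3 of image n land in the four lanes of array slice
// n * divRoundUp(C, 4) + c / 4. When C is not a multiple of 4 the tail lanes
// are zero-filled on the way in and simply skipped on the way out.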
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
                               texture2d_array<half, access::write> out[[texture(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;
  if (gid.x >= W || gid.y >= H) {
    return;
  }

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

// TODO: are the `else` branches needed?
// TODO: trick the optimizer for case where C == 4?
#define CHW_TO_CHWP4(idx, n, c_, h, w)                                     \
  if ((c_) < C) {                                                          \
    trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \
  } else {                                                                 \
    trns[idx] = 0.0h;                                                      \
  }

  half4 trns;
  CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHW_TO_CHWP4

  out.write(trns, gid.xy, gid.z);
}

kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
                                        texture2d<half, access::write> out[[texture(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

  half4 trns;
// TODO: are the `else` branches needed?
// TODO: trick the optimizer for case where C % 4 == 0?

#define CHW_TO_CHWP4(idx, c, h, w)                        \
  if ((c) < C) {                                          \
    trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \
  } else {                                                \
    trns[idx] = 0.0h;                                     \
  }

  CHW_TO_CHWP4(0, 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, 3, gid.y, gid.x);
#undef CHW_TO_CHWP4

  out.write(trns, gid.xy);
}

kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
                               device float* out[[buffer(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  half4 cs = in.read(gid.xy, gid.z);

#define CHWP4_TO_CHW(idx, n, c_, h, w)                                    \
  if ((c_) < C) {                                                         \
    out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \
  }

  CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        device float* out[[buffer(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

  half4 cs = in.read(gid.xy);

#define CHWP4_TO_CHW(idx, c, h, w)                       \
  if ((c) < C) {                                         \
    out[int(c) * H * W + int(h) * W + int(w)] = cs[idx]; \
  }

  CHWP4_TO_CHW(0, 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}
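
// Zero-insertion upscale for transposed convolution: output pixel (x, y)
// takes an input value only when (x - (kernel_ - 1 - pad)) and
// (y - (kernel_ - 1 - pad)) are both non-negative multiples of the stride;
// everything else is written as zero. For example, with kernel_ = 4,
// stride = 2, pad = 1, input pixel i maps to output pixel 2 * i + 2.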
kernel void convtranspose_upscale(texture2d_array<half, access::read> in[[texture(0)]],
                                  texture2d_array<half, access::write> out[[texture(1)]],
                                  ushort3 gid[[thread_position_in_grid]]) {
  // All resolved at compile time.
  // Assume symmetric kernel/stride/pad for now.
  const ushort kernel_ = ushort_arg_0;
  const ushort stride = ushort_arg_1;
  const ushort pad = ushort_arg_2;

  half4 zero(0.0h, 0.0h, 0.0h, 0.0h);

  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort2 gid_ = gid.xy;
  if (gid.x < kernel_ - 1 - pad || gid.y < kernel_ - 1 - pad) {
    out.write(zero, gid_, gid.z);
    return;
  }

  if (((gid.x - (kernel_ - 1 - pad)) % stride == 0) &&
      ((gid.y - (kernel_ - 1 - pad)) % stride == 0)) {
    ushort2 in_pos((gid.x - (kernel_ - 1 - pad)) / stride, (gid.y - (kernel_ - 1 - pad)) / stride);

    if (in_pos.x < in.get_width() && in_pos.y < in.get_height()) {
      half4 input = in.read(in_pos, gid.z);
      out.write(input, gid_, gid.z);
    } else {
      out.write(zero, gid_, gid.z);
    }
  } else {
    out.write(zero, gid_, gid.z);
  }
}
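
// col2im (and several kernels below) picks between a plain texture2d and a
// texture2d_array at pipeline-build time: the function constants route the
// same texture index to whichever parameter is live, so only one of in/ina
// (and out/outa) exists in a given specialization. Here the input binds as an
// array when the batch (ushort_arg_7) exceeds 1 or kernel_h * kernel_w * C
// (ushort_arg_0 * ushort_arg_1 * ushort_arg_6) exceeds 4.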
constant bool has_in_arr = (ushort_arg_7 > 1 || ushort_arg_0 * ushort_arg_1 * ushort_arg_6 > 4);
constant bool has_out_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool has_in_tex = (!has_in_arr);
constant bool has_out_tex = (!has_out_arr);

kernel void col2im(
    texture2d_array<half, access::read> ina[[ texture(0), function_constant(has_in_arr) ]],
    texture2d<half, access::read> in[[ texture(0), function_constant(has_in_tex) ]],
    texture2d_array<half, access::write> outa[[ texture(1), function_constant(has_out_arr) ]],
    texture2d<half, access::write> out[[ texture(1), function_constant(has_out_tex) ]],
    constant half4* bias[[buffer(0)]],
    ushort3 gid[[thread_position_in_grid]]) {
  if (has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }
  const ushort kernel_h = ushort_arg_0;
  const ushort kernel_w = ushort_arg_1;
  const ushort stride_h = ushort_arg_2;
  const ushort stride_w = ushort_arg_3;
  const ushort pad_l = ushort_arg_4;
  const ushort pad_t = ushort_arg_5;
  const ushort C = ushort_arg_6;
  //  const int N = ushort_arg_7;
  const ushort height_col = ushort_arg_8; // (outa.get_height() + pad + pad - kernel_) / stride + 1;
  const ushort width_col = ushort_arg_9;  // (outa.get_width() + pad + pad - kernel_) / stride + 1;

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  const ushort w = gid.x + pad_l;
  const ushort h = gid.y + pad_t;

  // compute the start and end of the output
  const ushort w_col_start = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
  const ushort w_col_end = min(ushort(w / stride_w + 1), ushort(width_col));
  const ushort h_col_start = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
  const ushort h_col_end = min(ushort(h / stride_h + 1), ushort(height_col));

  float4 val = static_cast<float4>(bias[c]);
  for (ushort h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (ushort w_col = w_col_start; w_col < w_col_end; ++w_col) {
      const ushort w_k = w - w_col * stride_w;
      const ushort h_k = h - h_col * stride_h;

      // layout is essentially: [N][K][K][C][H][W]
      // - where the divRoundUp(K * K * C, 4) channels are interleaved as usual.
      // Thus, it's actually [N][divRoundUp(K * K * C, 4)][H][W].

      // If C % 4 is not zero, then we have to play some games via partial indexing.
      // TODO: is it worth optimizing this loop via padding in C?
      if (C % 4 == 0) {
        ushort c_col = n * kernel_h * kernel_w * divRoundUp(C, 4) +
                       h_k * kernel_w * divRoundUp(C, 4) + w_k * divRoundUp(C, 4) + c;
        if (has_in_arr) {
          val += static_cast<float4>(ina.read(ushort2(w_col, h_col), c_col));
        }
        if (has_in_tex) {
          val += static_cast<float4>(in.read(ushort2(w_col, h_col), c_col));
        }
      } else {
        half4 components(0, 0, 0, 0);
        for (auto i = 0; i < 4; ++i) {
          ushort c_col_i = n * divRoundUp(kernel_h * kernel_w * C, 4) * 4 + h_k * kernel_w * C +
                           w_k * C + c * 4 + i;
          ushort c_col_i_z = c_col_i / 4;
          ushort c_col_i_off = c_col_i - c_col_i_z * 4;
          if (has_in_arr) {
            components[i] = ina.read(ushort2(w_col, h_col), c_col_i_z)[c_col_i_off];
          }
          if (has_in_tex) {
            components[i] = in.read(ushort2(w_col, h_col))[c_col_i_off];
          }
        }
        val += static_cast<float4>(components);
      }
    }
  }
  if (has_out_arr) {
    outa.write(static_cast<half4>(val), gid.xy, gid.z);
  }
  if (has_out_tex) {
    out.write(static_cast<half4>(val), gid.xy);
  }
}
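
// Stylizer pre/deprocessing: preprocess subtracts a per-channel mean from
// each uchar4 input pixel and adds a small noise term, tiling a buffer of
// noise_size / 4 half4 entries across the image; deprocess adds the mean back
// and clamps to [0, 255], with the alpha lane forced to 255, before writing
// bytes out.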
kernel void preprocess_stylizer(device uchar4* in[[buffer(0)]],
                                constant half* mean[[buffer(1)]],
                                constant half4* noise[[buffer(2)]],
                                texture2d<half, access::write> out[[texture(0)]],
                                ushort2 gid[[thread_position_in_grid]]) {

  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort noise_size = ushort_arg_0;

  half4 mean_half(mean[0], mean[1], mean[2], 0.0h);
  uint input_noise_idx = ((uint)out.get_width() * (uint)gid.y + (uint)gid.x) % (noise_size / 4);
  const half4 input_noise = noise[input_noise_idx];
  const uint W = out.get_width();
#define in_at(h, w) in[(uint)(h)*W + (uint)(w)]
  uchar4 input = in_at(gid.y, gid.x);
#undef in_at
  half4 input_half = static_cast<half4>(input);
  out.write(input_half - mean_half + input_noise, gid);
}

kernel void deprocess_stylizer(texture2d<half, access::read> in[[texture(0)]],
                               device uchar4* out[[buffer(0)]],
                               constant half* mean[[buffer(1)]],
                               ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= in.get_width() || gid.y >= in.get_height()) {
    return;
  }

  half4 value = in.read(gid);

  half4 mean_h(mean[0], mean[1], mean[2], 0.0h);
  half4 min_h(0.0h, 0.0h, 0.0h, 255.0h);
  half4 max_h(255.0h, 255.0h, 255.0h, 255.0h);
  half4 clamped = clamp(value + mean_h, min_h, max_h);
  const uint W = in.get_width();
#define out_at(h, w, v) out[(uint)(h)*W + (uint)(w)] = (v)
  out_at(gid.y, gid.x, static_cast<uchar4>(clamped));
#undef out_at
}
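
// Reflection padding mirrors out-of-range coordinates without repeating the
// border pixel: for an unpadded index h in [-pad, H + pad), h = max(h, -h)
// folds the left overhang (h = -2 -> 2) and h = min(h, 2 * H - h - 2) folds
// the right one (h = H -> H - 2), matching ReflectionPad2d semantics.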
kernel void reflection_padding_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        texture2d<half, access::write> out[[texture(1)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);
  out.write(in.read(inid), gid);
}

kernel void reflection_padding(texture2d_array<half, access::read> in[[texture(0)]],
                               texture2d_array<half, access::write> out[[texture(1)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);

  out.write(in.read(inid, gid.z), gid.xy, gid.z);
}
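
// 2x bilinear upsample: each output pixel samples the source at gid / 2 in
// pixel coordinates through a linear clamp-to-edge sampler, letting the
// texture unit do the interpolation.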
kernel void bilinear_upsample(texture2d<half, access::sample> in[[texture(0)]],
                              texture2d<half, access::write> out[[texture(1)]],
                              ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 src = gid / 2;
  constexpr sampler sampler(address::clamp_to_edge, filter::linear, coord::pixel);
  half4 value = in.sample(sampler, static_cast<float2>(src));
  out.write(value, gid);
}
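
// Elementwise ops. in0_is_tex selects the plain-texture path when
// ushort_arg_0 <= 1 and ushort_arg_1 <= 4 (apparently batch size and channel
// count). elementwise_mul/sub take a flat float buffer as the second operand
// and index it as in1[idx % last_dim] with idx enumerating pixels row-major,
// which tiles a length-last_dim vector across the texture; elementwise_add
// instead adds two textures of identical shape pointwise.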
constant bool in0_is_tex = ushort_arg_0 <= 1 && ushort_arg_1 <= 4;
constant bool in0_is_arr = !in0_is_tex;

kernel void elementwise_mul(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) * in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) * in1[idx % last_dim], gid_, gid.z);
  }
}

kernel void elementwise_sub(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) - in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) - in1[idx % last_dim], gid_, gid.z);
  }
}

kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  out.write(in0.read(gid) + in1.read(gid), gid);
}

kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 gid_ = gid.xy;
  out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z);
}
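
// Channel-dim concat over up to four inputs. ushort_arg_0..3 carry the input
// channel counts C0..C3 (0 marks an absent input); the remaining constants
// choose texture2d vs. texture2d_array bindings per input and for the output.
// The idx_N helpers below map an output channel to the input it comes from,
// and z_off locates image n's slice for channel c inside a packed texture
// array.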
constant bool has_in0_arg = (ushort_arg_0 > 0);
constant bool has_in1_arg = (ushort_arg_1 > 0);
constant bool has_in2_arg = (ushort_arg_2 > 0);
constant bool has_in3_arg = (ushort_arg_3 > 0);

constant bool has_in0_tex = (has_in0_arg && ushort_arg_0 <= 4 && ushort_arg_4 <= 1);
constant bool has_in1_tex = (has_in1_arg && ushort_arg_1 <= 4 && ushort_arg_4 <= 1);
constant bool has_in2_tex = (has_in2_arg && ushort_arg_2 <= 4 && ushort_arg_4 <= 1);
constant bool has_in3_tex = (has_in3_arg && ushort_arg_3 <= 4 && ushort_arg_4 <= 1);

constant bool has_in0_array = (has_in0_arg && !has_in0_tex);
constant bool has_in1_array = (has_in1_arg && !has_in1_tex);
constant bool has_in2_array = (has_in2_arg && !has_in2_tex);
constant bool has_in3_array = (has_in3_arg && !has_in3_tex);

constant bool concat_has_out_tex = (ushort_arg_4 <= 4 && ushort_arg_5 <= 1);
constant bool concat_has_out_array = !concat_has_out_tex;

inline ushort idx_3(ushort z, ushort C0, ushort C1, ushort C2, ushort C3) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  if (z < (C0 + C1 + C2)) {
    return 2;
  }
  return 3;
}

inline ushort idx_2(ushort z, ushort C0, ushort C1, ushort C2) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  return 2;
}

inline ushort idx_1(ushort z, ushort C0, ushort C1) {
  if (z < C0) {
    return 0;
  } else {
    return 1;
  }
}

inline ushort idx_0(ushort z, ushort C0) { return 0; }

// in a texture_array with size C, find the offset for image N at plane c.
inline constexpr ushort z_off(ushort n, ushort c, ushort C) { return n * divRoundUp(C, 4) + c / 4; }
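
// The concat kernel assembles each output half4 lane by lane. Fast path: when
// the four channels of an output slice all come from the same input and start
// 4-aligned within it, a single read supplies the whole half4 instead of four
// scalar gathers.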
kernel void concat(
                   texture2d<half, access::read> in0[[ texture(0), function_constant(has_in0_tex) ]],
                   texture2d<half, access::read> in1[[ texture(1), function_constant(has_in1_tex) ]],
                   texture2d<half, access::read> in2[[ texture(2), function_constant(has_in2_tex) ]],
                   texture2d<half, access::read> in3[[ texture(3), function_constant(has_in3_tex) ]],
                   texture2d_array<half, access::read> ina0[[ texture(0), function_constant(has_in0_array) ]],
                   texture2d_array<half, access::read> ina1[[ texture(1), function_constant(has_in1_array) ]],
                   texture2d_array<half, access::read> ina2[[ texture(2), function_constant(has_in2_array) ]],
                   texture2d_array<half, access::read> ina3[[ texture(3), function_constant(has_in3_array) ]],
                   texture2d<half, access::write> out[[texture(5),
                                                       function_constant(concat_has_out_tex) ]],
                   texture2d_array<half, access::write> outa[[texture(5),
                                                              function_constant(concat_has_out_array) ]],
                   ushort3 gid[[thread_position_in_grid]]) {
  if (concat_has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }

  const ushort C0 = ushort_arg_0;
  const ushort C1 = ushort_arg_1;
  const ushort C2 = ushort_arg_2;
  const ushort C3 = ushort_arg_3;
  const ushort C = C0 + C1 + C2 + C3;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);
  // Fill channel 4*c to 4*(c+1) of nth image of output

  ushort2 gid_ = gid.xy;
  half4 value;

  for (int off = 0; off < 4; ++off) {
    ushort cur_channel = c * 4 + off;
    ushort cur_idx = 0;
    if (cur_channel >= C) {
      break;
    }
    if (has_in3_arg) {
      cur_idx = idx_3(cur_channel, C0, C1, C2, C3);
    } else if (has_in2_arg) {
      cur_idx = idx_2(cur_channel, C0, C1, C2);
    } else if (has_in1_arg) {
      cur_idx = idx_1(cur_channel, C0, C1);
    } else if (has_in0_arg) {
      cur_idx = idx_0(cur_channel, C0);
    } else {
      // never reached.
      cur_idx = 0;
    }
    ushort src_off = 0;
    switch (cur_idx) {
      case 0:
        src_off = cur_channel % 4;
        break;
      case 1:
        src_off = (cur_channel - C0) % 4;
        break;
      case 2:
        src_off = (cur_channel - (C0 + C1)) % 4;
        break;
      case 3:
        src_off = (cur_channel - (C0 + C1 + C2)) % 4;
        break;
    }
    // try to see if we can only issue one read op for the 4 values
    bool fast_path = false;
    if (off == 0 && src_off == 0 && (cur_channel + 3) < C) {
      ushort last_idx = 0;
      if (has_in3_arg) {
        last_idx = idx_3(cur_channel + 3, C0, C1, C2, C3);
      } else if (has_in2_arg) {
        last_idx = idx_2(cur_channel + 3, C0, C1, C2);
      } else if (has_in1_arg) {
        last_idx = idx_1(cur_channel + 3, C0, C1);
      } else if (has_in0_arg) {
        last_idx = idx_0(cur_channel + 3, C0);
      } else {
        // never reached.
        last_idx = 0;
      }
      if (cur_idx == last_idx) {
        fast_path = true;
      }
    }
    switch (cur_idx) {
      case 0: {
        if (has_in0_tex) {
          if (fast_path) {
            value = in0.read(gid_);
          } else {
            value[off] = in0.read(gid_)[src_off];
          }
        }
        if (has_in0_array) {
          if (fast_path) {
            value = ina0.read(gid_, z_off(n, cur_channel, C0));
          } else {
            value[off] = ina0.read(gid_, z_off(n, cur_channel, C0))[src_off];
          }
        }
        break;
      }
      case 1: {
        if (has_in1_tex) {
          if (fast_path) {
            value = in1.read(gid_);
          } else {
            value[off] = in1.read(gid_)[src_off];
          }
        }
        if (has_in1_array) {
          if (fast_path) {
            value = ina1.read(gid_, z_off(n, cur_channel - C0, C1));
          } else {
            value[off] = ina1.read(gid_, z_off(n, cur_channel - C0, C1))[src_off];
          }
        }
        break;
      }
      case 2: {
        if (has_in2_tex) {
          if (fast_path) {
            value = in2.read(gid_);
          } else {
            value[off] = in2.read(gid_)[src_off];
          }
        }
        if (has_in2_array) {
          if (fast_path) {
            value = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2));
          } else {
            value[off] = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2))[src_off];
          }
        }
        break;
      }
      case 3: {
        if (has_in3_tex) {
          if (fast_path) {
            value = in3.read(gid_);
          } else {
            value[off] = in3.read(gid_)[src_off];
          }
        }
        if (has_in3_array) {
          if (fast_path) {
            value = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3));
          } else {
            value[off] = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3))[src_off];
          }
        }
        break;
      }
    }
    if (fast_path) {
      break;
    }
  }
  if (concat_has_out_tex) {
    out.write(value, gid_, gid.z);
  } else {
    outa.write(value, gid_, gid.z);
  }
}
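
// RoIAlign-style warp. Each ROI (half4 x1, y1, x2, y2; rois[n] for output
// image n) is scaled by spatial_scale, which the host passes as a ushort in
// units of 1/10000. Every output bin averages a grid of bilinear samples
// (a fixed sampling_ratio^2 when sampling_ratio > 0, otherwise an adaptive
// count) fetched through a linear clamp-to-edge sampler; the +0.5 recenters
// pixel coordinates on texel centers.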
using RoIT = half;
using RoIT4 = half4;
constant bool rw_has_in_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool rw_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
constant bool rw_has_in_tex = (!rw_has_in_arr);
constant bool rw_has_out_tex = (!rw_has_out_arr);
kernel void roi_warp(texture2d_array<half, access::sample> ina[[texture(0), function_constant(rw_has_in_arr)]],
                     texture2d<half, access::sample> in[[texture(0), function_constant(rw_has_in_tex)]],
                     texture2d_array<half, access::write> outa[[texture(1), function_constant(rw_has_out_arr)]],
                     texture2d<half, access::write> out[[texture(1), function_constant(rw_has_out_tex)]],
                     constant half4* rois[[buffer(0)]],
                     ushort3 gid[[thread_position_in_grid]]) {
  ushort out_width, out_height;
  if (rw_has_out_arr) {
    out_width = outa.get_width();
    out_height = outa.get_height();
  } else {
    out_width = out.get_width();
    out_height = out.get_height();
  }
  if (gid.x >= out_width || gid.y >= out_height) {
    return;
  }
  constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);

  const half spatial_scale = half(ushort_arg_0) / 10000;
  const ushort sampling_ratio = ushort_arg_1;
  const ushort C = ushort_arg_2;
  const ushort pw = gid.x;
  const ushort ph = gid.y;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z % divRoundUp(C, 4);

  const RoIT4 roi_scaled = rois[n] * spatial_scale;
  const RoIT roi_start_w = roi_scaled[0];
  const RoIT roi_start_h = roi_scaled[1];
  const RoIT roi_end_w = roi_scaled[2];
  const RoIT roi_end_h = roi_scaled[3];

  // Force malformed ROIs to be 1x1
  const RoIT roi_width = max(roi_end_w - roi_start_w, (RoIT)1.);
  const RoIT roi_height = max(roi_end_h - roi_start_h, (RoIT)1.);

  const RoIT bin_size_h = static_cast<RoIT>(roi_height) / static_cast<RoIT>(out_height);
  const RoIT bin_size_w = static_cast<RoIT>(roi_width) / static_cast<RoIT>(out_width);
  const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<RoIT>(out_height));
  const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<RoIT>(out_width));
  const ushort iy_upper = (sampling_ratio > 0) ? roi_bin_grid_h : (roi_bin_grid_h + 1);
  const ushort ix_upper = (sampling_ratio > 0) ? roi_bin_grid_w : (roi_bin_grid_w + 1);

  const RoIT count = iy_upper * ix_upper;

  RoIT4 output_val = 0.0;
  for (int iy = 0; iy < iy_upper; iy++) {
    for (int ix = 0; ix < ix_upper; ix++) {
      const RoIT y =
          roi_start_h + ph * bin_size_h + iy * bin_size_h / static_cast<RoIT>(roi_bin_grid_h);
      const RoIT x =
          roi_start_w + pw * bin_size_w + ix * bin_size_w / static_cast<RoIT>(roi_bin_grid_w);
      if (rw_has_in_arr) {
        output_val += ina.sample(s2, float2(x + 0.5, y + 0.5), c);
      } else {
        output_val += in.sample(s2, float2(x + 0.5, y + 0.5));
      }
    }
  }
  output_val /= count;
  if (rw_has_out_arr) {
    outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
  } else {
    out.write(static_cast<half4>(output_val), gid.xy);
  }
}
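
// Bitmask NMS. `indices` selects the proposals to consider; boxes are float4
// x1, y1, x2, y2 with inclusive coordinates, hence the +1 in the area terms.
// Threadgroups tile the proposal list 32 wide: thread trd_id of tile
// (row_start, col_start) compares its box against the tile's column boxes and
// sets bit i when IoU exceeds nms_thresh (a ushort in units of 1/10000). The
// per-row bitmasks are written to `mask` for the caller to combine.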
kernel void nms(device uint* mask[[buffer(0)]],
                constant float* proposals[[buffer(1)]],
                constant int* indices[[buffer(2)]],
                ushort2 tgid[[threadgroup_position_in_grid]],
                ushort2 tid[[thread_position_in_threadgroup]]) {
  const ushort num_proposals = ushort_arg_0;
  const ushort threads_per_group = ushort_arg_1;
  float nms_thresh = float(ushort_arg_2) / 10000.0;
  const ushort global_offset = ushort_arg_3;
  const ushort row_start = tgid.y;
  const ushort col_start = tgid.x;
  const ushort trd_id = tid.x;

  const short row_size = min(short(32), short(num_proposals - row_start * threads_per_group));
  const short col_size = min(short(32), short(num_proposals - col_start * threads_per_group));

  // mask the bit if the IoU between two proposals exceeds the threshold
  if (trd_id < row_size) {
    const ushort cur_idx = global_offset + row_start * threads_per_group + trd_id;
    const ushort offset = indices[cur_idx] * 4;
    const float4 cur_proposal = float4(
        proposals[offset], proposals[offset + 1], proposals[offset + 2], proposals[offset + 3]);
    uint cur_mask = 0;
    ushort group_start = 0; // start index within group
    if (row_start == col_start) {
      // if in the same group, start from the next
      group_start = trd_id + 1;
    }
    for (ushort i = group_start; i < col_size; i++) {
      float4 a = cur_proposal;
      ushort idx = indices[global_offset + col_start * threads_per_group + i] * 4;
      float4 b = float4(proposals[idx], proposals[idx + 1], proposals[idx + 2], proposals[idx + 3]);
      float left = max(a[0], b[0]);
      float right = min(a[2], b[2]);
      float top = max(a[1], b[1]);
      float bottom = min(a[3], b[3]);
      float width = max(right - left + 1.0, 0.0);
      float height = max(bottom - top + 1.0, 0.0);
      float interS = width * height;
      float Sa = (a[2] - a[0] + 1.0) * (a[3] - a[1] + 1.0);
      float Sb = (b[2] - b[0] + 1.0) * (b[3] - b[1] + 1.0);
      float iou = interS / (Sa + Sb - interS);
      if (iou - nms_thresh > 0) {
        cur_mask |= 1U << i;
      }
    }
    ushort col_blocks = (num_proposals + threads_per_group - 1) / threads_per_group;
    mask[cur_idx * col_blocks + col_start] = cur_mask;
  }
}
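
// Nearest-neighbor resize. The scales again arrive as ushort fixed-point
// values in units of 1/10000, and the source coordinate is simply the floor
// of gid / scale.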
kernel void resize_nearest(texture2d_array<half, access::sample> in[[texture(0)]],
                           texture2d_array<half, access::write> out[[texture(1)]],
                           ushort3 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y), gid.z), gid.xy, gid.z);
}

kernel void resize_nearest_nonarray(texture2d<half, access::sample> in[[texture(0)]],
                                    texture2d<half, access::write> out[[texture(1)]],
                                    ushort2 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y)), gid.xy);
}
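
// Channel shuffle with `groups` groups of K channels each (the indexing
// implies C == groups * K): output channel k * groups + g reads input channel
// g * K + k. Here cur_channel is the output channel, so
// channel_id = cur_channel / groups and group_id = cur_channel % groups
// recover (k, g), and the source channel is c0 = group_id * K + channel_id.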
kernel void channel_shuffle(
                            texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(1), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(1), function_constant(in0_is_arr)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort C = ushort_arg_1;
  ushort K = ushort_arg_2;
  ushort groups = ushort_arg_3;

  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);
  half4 value;
  ushort2 gid_ = gid.xy;
  for (int off = 0; off < 4; ++off) {
    ushort cur_channel = c * 4 + off;
    if (cur_channel >= C) {
      break;
    }
    ushort channel_id = cur_channel / groups;
    ushort group_id = cur_channel % groups;
    ushort c0 = group_id * K + channel_id;
    if (in0_is_tex) {
      value[off] = in0.read(gid_)[c0 % 4];
    } else {
      value[off] = ina0.read(gid_, c0 / 4 + n * divRoundUp(C, 4))[c0 % 4];
    }
  }
  if (in0_is_tex) {
    out.write(value, gid_);
  } else {
    outa.write(value, gid_, gid.z);
  }
}