import triton
import triton.language as tl


@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us
    return x + tl.zeros((1,), tl.int1)


@triton.jit
def is_floating(x):
    return promote_to_tensor(x).dtype.is_floating()


@triton.jit
def _prod_accumulate(a, b):
    return a * b


@triton.jit
def prod(input, axis):
    return tl.reduce(input, axis, _prod_accumulate)


@triton.jit
def minimum(a, b):
    mask = a < b
    if is_floating(a):
        # Propagate NaNs
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def maximum(a, b):
    mask = a > b
    if is_floating(a):
        # Propagate NaNs
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def min2(a, dim):
    return tl.reduce(a, dim, minimum)


@triton.jit
def max2(a, dim):
    return tl.reduce(a, dim, maximum)
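

# --- Example (sketch, not part of the original helpers) ---------------------
# A minimal kernel showing how max2 behaves: because maximum propagates NaNs,
# the reduction result is NaN whenever any in-bounds lane is NaN, unlike a
# plain tl.max. The _demo names are hypothetical.
@triton.jit
def _max2_demo_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Out-of-bounds lanes get -inf, the identity of maximum
    x = tl.load(x_ptr + offs, mask=offs < N, other=float("-inf"))
    tl.store(out_ptr, max2(x, 0))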


@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
    return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
    return tl.reduce((value, index), dim, maximum_with_index)
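

# --- Example (sketch, not part of the original helpers) ---------------------
# Block-level argmax built on max_with_index, with torch-style tie-breaking
# (lowest index wins). The _demo names are hypothetical.
@triton.jit
def _argmax_demo_kernel(x_ptr, val_ptr, idx_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=float("-inf"))
    # Give masked lanes an out-of-range index so they lose all ties
    idx = tl.where(mask, offs, N)
    value, index = max_with_index(x, idx, 0)
    tl.store(val_ptr, value)
    tl.store(idx_ptr, index)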


@triton.jit
def welford_reduce(value, mean, m2, weight):
    delta = value - mean
    new_weight = weight + 1
    new_mean = mean + delta / new_weight
    return (
        new_mean,
        m2 + delta * (value - new_mean),
        new_weight,
    )


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)
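

# --- Example (sketch, not part of the original helpers) ---------------------
# One way to use the Welford helpers: treat each loaded element as a
# one-sample state (mean=x, m2=0, weight=1; masked lanes weight=0) and let
# welford() merge all the states in a single reduction. The _demo name is
# hypothetical; m2 / weight is the population variance.
@triton.jit
def _mean_var_demo_kernel(x_ptr, mean_ptr, var_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    weight = tl.where(mask, 1.0, 0.0)
    mean, m2, weight = welford(x, tl.zeros([BLOCK], tl.float32), weight, 0)
    tl.store(mean_ptr, mean)
    tl.store(var_ptr, m2 / weight)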


@triton.jit
def device_assert_then(cond, msg, r):
    tl.device_assert(cond, msg)
    return r


@triton.jit
def randint64(seed, offset, low, high):
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result
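

# --- Example (sketch, not part of the original helpers) ---------------------
# Filling a buffer with random int64 values in [low, high). Note the modulo
# above introduces a small bias whenever high - low is not a power of two.
# The _demo name is hypothetical.
@triton.jit
def _randint64_demo_kernel(out_ptr, seed, low, high, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    r = randint64(seed, offs, low, high)
    tl.store(out_ptr + offs, r, mask=offs < N)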


@triton.jit
def _any_combine(a, b):
    return a | b


@triton.jit
def any(a, dim):
    return tl.reduce(a, dim, _any_combine)
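

# --- Example (sketch, not part of the original helpers) ---------------------
# Reducing a predicate with any(): the result is true iff some in-bounds lane
# is nonzero. Masked lanes contribute 0, the identity of the |-combine. The
# _demo name is hypothetical.
@triton.jit
def _any_demo_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0)
    tl.store(out_ptr, any((x != 0) & mask, 0).to(tl.int8))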


@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]
    """

    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low
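

# --- Example (sketch, not part of the original helpers) ---------------------
# Bucketizing a block of values against a sorted, device-resident boundaries
# array (a searchsorted). OFFSETS_SIZE is the number of boundaries; whether
# equal values fall left or right of a boundary is controlled by ``right``.
# The _demo names are hypothetical.
@triton.jit
def _bucketize_demo_kernel(
    values_ptr, boundaries_ptr, out_ptr, N,
    OFFSETS_SIZE: tl.constexpr, BLOCK: tl.constexpr,
):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    v = tl.load(values_ptr + offs, mask=mask, other=0.0)
    bucket = bucketize_binary_search(
        v, boundaries_ptr, tl.int32, False, OFFSETS_SIZE, [BLOCK]
    )
    tl.store(out_ptr + offs, bucket, mask=mask)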


@triton.jit
def pack_value_flag(
    value,
    flag,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    # tensor.to() doesn't unwrap constexpr dtypes, so unwrap them manually
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
    return flag.to(DTYPE_PACK) | (uv << bitwidth)


@triton.jit
def unpack_value(
    pack,
    DTYPE_VALUE: tl.constexpr,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
):
    # tensor.to() doesn't unwrap constexpr dtypes, so unwrap them manually
    DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
    return value_uint.to(DTYPE_VALUE, bitcast=True)


@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
    return pack.to(DTYPE_FLAG)
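

# --- Example (sketch, not part of the original helpers) ---------------------
# Round trip of the pack/unpack helpers for a float32 value: the value is
# bitcast into the high 32 bits of a uint64 and the flag occupies the low
# bits, so both survive a single atomic exchange. The _demo name is
# hypothetical.
@triton.jit
def _pack_roundtrip_demo(value):
    flag = tl.full(value.shape, 1, tl.uint32)
    pack = pack_value_flag(value, flag, tl.uint32, tl.uint64)
    # Returns (value, flag) recovered from the packed word
    return (
        unpack_value(pack, tl.float32, tl.uint32),
        unpack_flag(pack, tl.uint32),
    )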


@triton.jit
def exclusive_scan_decoupled_lookback(
    scratch_base,
    block_value,
    index,
    combine_fn,
    init,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    init: Scalar value equal to the identity of combine_fn
    DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
    DTYPE_PACK: Unsigned type twice the width of block_value

    NOTE: This function is limited to values which are 32-bits or less.
    """
    DTYPE_VALUE = block_value.dtype
    # Publish this block's aggregate with flag=1 ("aggregate available")
    pack = pack_value_flag(
        block_value,
        tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")

    exclusive_prefix = init
    test_target = index - 1
    while test_target >= 0:
        # Spin until the predecessor publishes; tl.atomic_add acts as an
        # atomic load without the membar which tl.load generates
        flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
        while flag == 0:
            pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
            flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)

        value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
        exclusive_prefix = combine_fn(value, exclusive_prefix)

        # flag == 2 means the predecessor published its inclusive prefix,
        # which already folds in every earlier block, so stop looking back
        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    pack = pack_value_flag(
        inclusive_prefix,
        tl.full([], 2, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
    return exclusive_prefix
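

# --- Example (sketch, not part of the original helpers) ---------------------
# A single-pass, grid-wide cumsum built on the look-back scan. Assumes: the
# scratch buffer holds one zero-initialized uint64 per block (flag 0 means
# "not ready"), values are 32-bit (float32 here), and blocks with lower
# program ids make forward progress, as the look-back protocol requires.
# The _demo names are hypothetical.
@triton.jit
def _add_combine_demo(a, b):
    return a + b


@triton.jit
def _cumsum_demo_kernel(x_ptr, out_ptr, scratch_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # Exclusive prefix: the combined sum of all earlier blocks
    prefix = exclusive_scan_decoupled_lookback(
        scratch_ptr, tl.sum(x, 0), pid, _add_combine_demo, 0.0,
        tl.uint32, tl.uint64,
    )
    tl.store(out_ptr + offs, prefix + tl.cumsum(x, 0), mask=mask)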


@triton.jit
def exclusive_scan_decoupled_lookback_64(
    scratch_base, block_value, index, combine_fn, init
):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block, must be 64-bits wide
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    init: Scalar value equal to the identity of combine_fn
    """
    # Each block owns a 3-word record: [flag, block aggregate, inclusive prefix]
    block_value_u64 = block_value.to(tl.uint64, bitcast=True)
    tl.store(scratch_base + 3 * index + 1, block_value_u64)
    flag_one = tl.full([], 1, tl.uint64)
    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")

    exclusive_prefix = init
    test_target = index - 1
    while test_target >= 0:
        flag = tl.full([], 0, tl.uint64)
        while flag == 0:
            flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")

        # The flag doubles as the slot offset: 1 selects the block aggregate,
        # 2 selects the inclusive prefix
        value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
        value = value_u64.to(block_value.dtype, bitcast=True)
        exclusive_prefix = combine_fn(value, exclusive_prefix)

        # A predecessor with flag == 2 already folds in every earlier block,
        # so stop looking back
        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
    tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
    flag_two = tl.full([], 2, tl.uint64)
    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")

    return exclusive_prefix
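

# --- Usage note (sketch, not part of the original helpers) ------------------
# The 64-bit variant keeps the flag and both values in separate words, so its
# scratch space needs three zero-filled uint64 slots per block, e.g. on the
# host (hypothetical): scratch = torch.zeros(3 * num_blocks,
# dtype=torch.uint64, device="cuda"); each program then passes its block
# aggregate and program id, mirroring the 32-bit example above.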


@triton.jit
def frexp(x):
    # Decompose x into mantissa * 2**exponent, with 0 mapping to (0, 0)
    y = tl.math.ilogb(x) + 1
    exponent = tl.where(x == 0, 0, y)
    mantissa = tl.where(x == 0, 0, tl.math.ldexp(x, -y))
    return mantissa, exponent
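

# --- Example (sketch, not part of the original helpers) ---------------------
# frexp decomposes finite nonzero x into mantissa * 2**exponent with
# |mantissa| in [0.5, 1), and maps 0 to (0, 0); ldexp therefore reconstructs
# x. The _demo name is hypothetical.
@triton.jit
def _frexp_demo_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    mantissa, exponent = frexp(x)
    # For finite inputs, out should reproduce x
    tl.store(out_ptr + offs, tl.math.ldexp(mantissa, exponent), mask=mask)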