pytorch / triton_helpers.py

import triton
import triton.language as tl


@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us
    return x + tl.zeros((1,), tl.int1)


@triton.jit
def is_floating(x):
    return promote_to_tensor(x).dtype.is_floating()
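
# Illustrative note (not from the upstream source): promote_to_tensor lets
# helpers below accept Python scalars as well as tensors; adding the int1
# zeros tensor promotes a scalar to a tensor with a concrete dtype, so e.g.
# is_floating(0.0) should be True while is_floating(0) should be False.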


@triton.jit
def _prod_accumulate(a, b):
    return a * b


@triton.jit
def prod(input, axis):
    return tl.reduce(input, axis, _prod_accumulate)
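
# Illustrative example (not from the upstream source): reducing a 1D block
# with `prod` multiplies its elements, e.g. inside a kernel
#   prod(tl.full((4,), 2.0, tl.float32), 0)  # -> 16.0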


@triton.jit
def minimum(a, b):
    mask = a < b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def maximum(a, b):
    mask = a > b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def min2(a, dim):
    return tl.reduce(a, dim, minimum)


@triton.jit
def max2(a, dim):
    return tl.reduce(a, dim, maximum)
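
# Illustrative note (not from the upstream source): the extra `a != a` check
# makes these helpers propagate NaNs, matching torch.minimum / torch.maximum:
#   minimum(float("nan"), 1.0)  # NaN: `a != a` forces the mask, so a wins
#   minimum(1.0, float("nan"))  # NaN: `a < b` is False, so b (NaN) is chosen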


@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
    return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
    return tl.reduce((value, index), dim, maximum_with_index)
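
# Illustrative sketch (not from the upstream source): a per-row argmin kernel
# built on min_with_index. The kernel name, pointers, and BLOCK size are
# hypothetical; each program reduces one row of an (M, N) row-major tensor,
# and BLOCK must be a power of two covering N.
@triton.jit
def _rowwise_argmin_kernel(x_ptr, val_ptr, idx_ptr, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    # Out-of-range lanes load +inf so they never win the minimum
    x = tl.load(x_ptr + row * N + cols, mask=mask, other=float("inf"))
    val, idx = min_with_index(x, cols.to(tl.int64), 0)
    tl.store(val_ptr + row, val)
    tl.store(idx_ptr + row, idx)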


@triton.jit
def welford_reduce(value, mean, m2, weight):
    delta = value - mean
    new_weight = weight + 1
    new_mean = mean + delta / new_weight
    return (
        new_mean,
        m2 + delta * (value - new_mean),
        new_weight,
    )


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)
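
# Worked example (illustrative, not from the upstream source): combining two
# partial (mean, m2, weight) triples with welford_combine:
#   group 1 = values [1, 2]: mean=1.5, m2=0.5, weight=2
#   group 2 = values [3, 4]: mean=3.5, m2=0.5, weight=2
#   delta = 2.0, new_weight = 4, w2_over_w = 0.5
#   -> mean = 1.5 + 2.0 * 0.5           = 2.5  (mean of [1, 2, 3, 4])
#   -> m2   = 0.5 + 0.5 + 4.0 * 2 * 0.5 = 5.0  (sum of squared deviations)
# The population variance then follows as m2 / weight = 1.25.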


@triton.jit
def device_assert_then(cond, msg, r):
    tl.device_assert(cond, msg)
    return r
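
# Illustrative usage (not from the upstream source): run tl.device_assert
# before returning a value, tying the check to the value in the dataflow:
#   safe_idx = device_assert_then(idx >= 0, "negative index", idx)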


@triton.jit
def randint64(seed, offset, low, high):
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result
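
# Illustrative note (not from the upstream source): two of the four 32-bit
# draws are packed into one uint64 and reduced into [low, high) by modulo:
#   randint64(seed, offs, -5, 5)  # int64 values in [-5, 5)
# (r2 and r3 from randint4x are unused here.)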


@triton.jit
def _any_combine(a, b):
    return a | b


@triton.jit
def any(a, dim):
    return tl.reduce(a, dim, _any_combine)


@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]
    """

    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low
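
# Worked example (illustrative, not from the upstream source), with
# offsets_ptr -> [1, 3, 5] and OFFSETS_SIZE = 3:
#   value 3, right=False -> 1  (index of first offset >= value)
#   value 3, right=True  -> 2  (index of first offset >  value)
#   value 9, either mode -> 3  (past the last offset)
# matching torch.bucketize semantics for the `right` flag.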


@triton.jit
def pack_value_flag(
    value,
    flag,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    # Workaround for a Triton bug: tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
    return flag.to(DTYPE_PACK) | (uv << bitwidth)


@triton.jit
def unpack_value(
    pack,
    DTYPE_VALUE,
    DTYPE_VALUE_AS_UINT,
):
    # Workaround for a Triton bug: tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
    return value_uint.to(DTYPE_VALUE, bitcast=True)


@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
    return pack.to(DTYPE_FLAG)
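
# Worked example (illustrative, not from the upstream source): packing a
# float32 value together with a uint32 flag into a uint64:
#   pack_value_flag(1.0, flag=1, DTYPE_VALUE_AS_UINT=tl.uint32,
#                   DTYPE_PACK=tl.uint64)
#   -> (0x3F800000 << 32) | 1, i.e. value bits in the high half, flag low.
# unpack_value recovers 1.0 and unpack_flag recovers 1.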


@triton.jit
def exclusive_scan_decoupled_lookback(
    scratch_base,
    block_value,
    index,
    combine_fn,
    init,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    init: Scalar value equal to the identity of combine_fn
    DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
    DTYPE_PACK: Unsigned type twice the width of block_value

    NOTE: This function is limited to values which are 32 bits or less.
    """
    DTYPE_VALUE = block_value.dtype
    pack = pack_value_flag(
        block_value,
        tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")

    exclusive_prefix = init
    test_target = index - 1
    while test_target >= 0:
        # Emulate tl.atomic_load: spin on atomic_add(ptr, 0) until flag != 0
        flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
        while flag == 0:
            pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
            flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)

        value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
        exclusive_prefix = combine_fn(value, exclusive_prefix)

        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    pack = pack_value_flag(
        inclusive_prefix,
        tl.full([], 2, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
    return exclusive_prefix
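
# Illustrative note (not from the upstream source) on the flag protocol used
# above: each block's scratch slot holds (value << bitwidth) | flag, where
#   flag 0 = slot not yet written,
#   flag 1 = the block's local aggregate is available,
#   flag 2 = the block's inclusive prefix is available (lookback can stop).
# Packing value and flag into one word lets a single relaxed atomic publish
# both, which is presumably why values are limited to 32 bits: the packed
# word must fit in a 64-bit atomic.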
278

279

280
@triton.jit
281
def exclusive_scan_decoupled_lookback_64(
282
    scratch_base, block_value, index, combine_fn, init
283
):
284
    """Compute exclusive scan of a scalar value between blocks
285

286
    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
287

288
    scratch_base: Pointer to scratch space in global memory
289
    block_value: Scalar value for this block, must be 64-bits wide
290
    index: Scalar index of this block relative to the current scan
291
    combine_fn: Function ``(value, value) -> value`` which is scanned over
292
    init: Scalar value equal to the identiy of combine_fn
293
    """
294
    block_value_u64 = block_value.to(tl.uint64, bitcast=True)
295
    tl.store(scratch_base + 3 * index + 1, block_value_u64)
296
    tl.debug_barrier()
297
    flag_one = tl.full([], 1, tl.uint64)
298
    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")
299

300
    exclusive_prefix = init
301
    test_target = index - 1
302
    while test_target >= 0:
303
        flag = tl.full([], 0, tl.uint64)
304
        while flag == 0:
305
            flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")
306

307
        value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
308
        value = value_u64.to(block_value.dtype, bitcast=True)
309
        exclusive_prefix = combine_fn(value, exclusive_prefix)
310

311
        if flag == 2:
312
            test_target = -1
313
        else:
314
            test_target = test_target - 1
315

316
    # Make inclusive block sum visible to other blocks
317
    inclusive_prefix = combine_fn(exclusive_prefix, block_value)
318
    inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
319
    tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
320
    tl.debug_barrier()
321
    flag_two = tl.full([], 2, tl.uint64)
322
    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")
323

324
    return exclusive_prefix
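
# Illustrative note (not from the upstream source) on the 64-bit variant:
# values no longer fit in a packed word, so each block gets three uint64
# scratch slots laid out as [flag, local aggregate, inclusive prefix]. The
# flag value (1 or 2) doubles as the offset of the slot to read:
#   tl.load(scratch_base + 3 * test_target + flag)
# and release/acquire ordering replaces the single packed relaxed atomic.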


@triton.jit
def frexp(x):
    # TODO(isuruf): use inline_asm_elementwise here
    y = tl.math.ilogb(x) + 1
    exponent = tl.where(x == 0, 0, y)
    mantissa = tl.where(x == 0, 0, tl.math.ldexp(x, -y))
    return mantissa, exponent
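
# Worked example (illustrative, not from the upstream source):
#   frexp(8.0): ilogb(8.0) = 3, so y = 4,
#   mantissa = ldexp(8.0, -4) = 0.5, exponent = 4  ->  8.0 == 0.5 * 2**4
#   frexp(0.0) returns (0.0, 0) via the tl.where guards.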
