ClickHouse

Форк
0
/
AggregateFunctionIf.cpp 
505 строк · 19.7 Кб
1
#include "AggregateFunctionCombinatorFactory.h"
2
#include "AggregateFunctionIf.h"
3
#include "AggregateFunctionNull.h"
4

5
namespace DB
6
{
7

8
/// Error codes referenced by the exceptions thrown in this file.
/// They are only declared here; the definitions live elsewhere.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
14

15
class AggregateFunctionCombinatorIf final : public IAggregateFunctionCombinator
16
{
17
public:
18
    String getName() const override { return "If"; }
19

20
    DataTypes transformArguments(const DataTypes & arguments) const override
21
    {
22
        if (arguments.empty())
23
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
24
                "Incorrect number of arguments for aggregate function with {} suffix", getName());
25

26
        if (!isUInt8(arguments.back()) && !arguments.back()->onlyNull())
27
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of last argument for "
28
                            "aggregate function with {} suffix", arguments.back()->getName(), getName());
29

30
        return DataTypes(arguments.begin(), std::prev(arguments.end()));
31
    }
32

33
    AggregateFunctionPtr transformAggregateFunction(
34
        const AggregateFunctionPtr & nested_function,
35
        const AggregateFunctionProperties &,
36
        const DataTypes & arguments,
37
        const Array & params) const override
38
    {
39
        return std::make_shared<AggregateFunctionIf>(nested_function, arguments, params);
40
    }
41
};
42

43

44
/** There are two cases: for single argument and variadic.
45
  * Code for single argument is much more efficient.
46
  */
47
template <bool result_is_nullable, bool serialize_flag>
48
class AggregateFunctionIfNullUnary final
49
    : public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
50
        AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>
51
{
52
private:
53
    size_t num_arguments;
54
    bool filter_is_nullable = false;
55
    bool filter_is_only_null = false;
56

57
    /// The name of the nested function, including combinators (i.e. *If)
58
    ///
59
    /// getName() from the nested_function cannot be used because in case of *If combinator
60
    /// with Nullable argument nested_function will point to the function without combinator.
61
    /// (I.e. sumIf(Nullable, 1) -> sum()), and distributed query processing will fail.
62
    ///
63
    /// And nested_function cannot point to the function with *If since
64
    /// due to optimization in the add() which pass only one column with the result,
65
    /// and so AggregateFunctionIf::add() cannot be called this way
66
    /// (it write to the last argument -- num_arguments-1).
67
    ///
68
    /// And to avoid extra level of indirection, the name of function is cached:
69
    ///
70
    ///     AggregateFunctionIfNullUnary::add -> [ AggregateFunctionIf::add -> ] AggregateFunctionSum::add
71
    String name;
72

73
    using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
74
        AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>;
75

76
    inline bool singleFilter(const IColumn ** columns, size_t row_num) const
77
    {
78
        const IColumn * filter_column = columns[num_arguments - 1];
79

80
        if (filter_is_nullable)
81
        {
82
            const ColumnNullable * nullable_column = assert_cast<const ColumnNullable *>(filter_column);
83
            filter_column = nullable_column->getNestedColumnPtr().get();
84
            const UInt8 * filter_null_map = nullable_column->getNullMapData().data();
85

86
            return assert_cast<const ColumnUInt8 &>(*filter_column).getData()[row_num] && !filter_null_map[row_num];
87
        }
88

89
        return assert_cast<const ColumnUInt8 &>(*filter_column).getData()[row_num];
90
    }
91

92
public:
93
    String getName() const override
94
    {
95
        return name;
96
    }
97

98
    AggregateFunctionIfNullUnary(const String & name_, AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
99
        : Base(std::move(nested_function_), arguments, params)
100
        , num_arguments(arguments.size())
101
        , name(name_)
102
    {
103
        if (num_arguments == 0)
104
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
105
                "Aggregate function {} require at least one argument", getName());
106

107
        filter_is_nullable = arguments[num_arguments - 1]->isNullable();
108
        filter_is_only_null = arguments[num_arguments - 1]->onlyNull();
109
    }
110

111
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
112
    {
113
        if (filter_is_only_null)
114
            return;
115

116
        const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
117
        const IColumn * nested_column = &column->getNestedColumn();
118
        if (!column->isNullAt(row_num) && singleFilter(columns, row_num))
119
        {
120
            this->setFlag(place);
121
            this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
122
        }
123
    }
124

125
    void addBatchSinglePlace(
126
        size_t row_begin,
127
        size_t row_end,
128
        AggregateDataPtr __restrict place,
129
        const IColumn ** columns,
130
        Arena * arena,
131
        ssize_t) const override
132
    {
133
        if (filter_is_only_null)
134
            return;
135

136
        const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
137
        const UInt8 * null_map = column->getNullMapData().data();
138
        const IColumn * columns_param[] = {&column->getNestedColumn()};
139

140
        const IColumn * filter_column = columns[num_arguments - 1];
141

142
        const UInt8 * filter_values = nullptr;
143
        const UInt8 * filter_null_map = nullptr;
144

145
        if (filter_is_nullable)
146
        {
147
            const ColumnNullable * nullable_column = assert_cast<const ColumnNullable *>(filter_column);
148
            filter_column = nullable_column->getNestedColumnPtr().get();
149
            filter_null_map = nullable_column->getNullMapData().data();
150
        }
151

152
        filter_values = assert_cast<const ColumnUInt8 *>(filter_column)->getData().data();
153

154
        /// Combine the 2 flag arrays so we can call a simplified version (one check vs 2)
155
        /// Note that now the null map will contain 0 if not null and not filtered, or 1 for null or filtered (or both)
156

157
        auto final_nulls = std::make_unique<UInt8[]>(row_end);
158

159
        if (filter_null_map)
160
            for (size_t i = row_begin; i < row_end; ++i)
161
                final_nulls[i] = (!!null_map[i]) | (!filter_values[i]) | (!!filter_null_map[i]);
162
        else
163
            for (size_t i = row_begin; i < row_end; ++i)
164
                final_nulls[i] = (!!null_map[i]) | (!filter_values[i]);
165

166
        if constexpr (result_is_nullable)
167
        {
168
            if (!memoryIsByte(final_nulls.get(), row_begin, row_end, 1))
169
                this->setFlag(place);
170
            else
171
                return; /// No work to do.
172
        }
173

174
        this->nested_function->addBatchSinglePlaceNotNull(
175
            row_begin,
176
            row_end,
177
            this->nestedPlace(place),
178
            columns_param,
179
            final_nulls.get(),
180
            arena,
181
            -1);
182
    }
183

184
#if USE_EMBEDDED_COMPILER
185

186
    bool isCompilable() const override
187
    {
188
        return canBeNativeType(*this->argument_types.back()) && this->nested_function->isCompilable();
189
    }
190

191
    void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
192
    {
193
        llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
194

195
        const auto & nullable_type = arguments[0].type;
196
        const auto & nullable_value = arguments[0].value;
197

198
        auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});
199
        auto * is_null_value = b.CreateExtractValue(nullable_value, {1});
200

201
        const auto & predicate_type = arguments.back().type;
202
        auto * predicate_value = arguments.back().value;
203
        auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);
204

205
        auto * head = b.GetInsertBlock();
206

207
        auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
208
        auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
209
        auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
210

211
        b.CreateCondBr(b.CreateAnd(b.CreateNot(is_null_value), is_predicate_true), if_not_null, if_null);
212

213
        b.SetInsertPoint(if_null);
214
        b.CreateBr(join_block);
215

216
        b.SetInsertPoint(if_not_null);
217

218
        if constexpr (result_is_nullable)
219
            b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
220

221
        auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size);
222
        this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { ValueWithType(wrapped_value, removeNullable(nullable_type)) });
223
        b.CreateBr(join_block);
224

225
        b.SetInsertPoint(join_block);
226
    }
227

228
#endif
229

230
};
231

232
template <bool result_is_nullable, bool serialize_flag>
233
class AggregateFunctionIfNullVariadic final : public AggregateFunctionNullBase<
234
                                                  result_is_nullable,
235
                                                  serialize_flag,
236
                                                  AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag>>
237
{
238
private:
239
    bool filter_is_only_null = false;
240

241
public:
242

243
    String getName() const override
244
    {
245
        return Base::getName();
246
    }
247

248
    AggregateFunctionIfNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
249
        : Base(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size())
250
    {
251
        if (number_of_arguments == 1)
252
            throw Exception(ErrorCodes::LOGICAL_ERROR, "Single argument is passed to AggregateFunctionIfNullVariadic");
253

254
        if (number_of_arguments > MAX_ARGS)
255
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
256
                "Maximum number of arguments for aggregate function with Nullable types is {}", toString(MAX_ARGS));
257

258
        for (size_t i = 0; i < number_of_arguments; ++i)
259
            is_nullable[i] = arguments[i]->isNullable();
260

261
        filter_is_only_null = arguments.back()->onlyNull();
262
    }
263

264
    static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments)
265
    {
266
        return assert_cast<const ColumnUInt8 &>(*columns[num_arguments - 1]).getData()[row_num];
267
    }
268

269
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
270
    {
271
        /// This container stores the columns we really pass to the nested function.
272
        const IColumn * nested_columns[number_of_arguments];
273

274
        for (size_t i = 0; i < number_of_arguments; ++i)
275
        {
276
            if (is_nullable[i])
277
            {
278
                const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);
279
                if (nullable_col.isNullAt(row_num))
280
                {
281
                    /// If at least one column has a null value in the current row,
282
                    /// we don't process this row.
283
                    return;
284
                }
285
                nested_columns[i] = &nullable_col.getNestedColumn();
286
            }
287
            else
288
                nested_columns[i] = columns[i];
289
        }
290

291
        if (singleFilter(nested_columns, row_num, number_of_arguments))
292
        {
293
            this->setFlag(place);
294
            this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
295
        }
296
    }
297

298
    void addBatchSinglePlace(
299
        size_t row_begin, size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, Arena * arena, ssize_t) const final
300
    {
301
        if (filter_is_only_null)
302
            return;
303

304
        std::unique_ptr<UInt8[]> final_null_flags = std::make_unique<UInt8[]>(row_end);
305
        const size_t filter_column_num = number_of_arguments - 1;
306

307
        if (is_nullable[filter_column_num])
308
        {
309
            const ColumnNullable * nullable_column = assert_cast<const ColumnNullable *>(columns[filter_column_num]);
310
            const IColumn & filter_column = nullable_column->getNestedColumn();
311
            const UInt8 * filter_null_map = nullable_column->getNullMapColumn().getData().data();
312
            const UInt8 * filter_values = assert_cast<const ColumnUInt8 &>(filter_column).getData().data();
313

314
            for (size_t i = row_begin; i < row_end; i++)
315
            {
316
                final_null_flags[i] = filter_null_map[i] || !filter_values[i];
317
            }
318
        }
319
        else
320
        {
321
            const IColumn * filter_column = columns[filter_column_num];
322
            const UInt8 * filter_values = assert_cast<const ColumnUInt8 *>(filter_column)->getData().data();
323
            for (size_t i = row_begin; i < row_end; i++)
324
                final_null_flags[i] = !filter_values[i];
325
        }
326

327
        const IColumn * nested_columns[number_of_arguments];
328
        for (size_t arg = 0; arg < number_of_arguments; arg++)
329
        {
330
            if (is_nullable[arg])
331
            {
332
                const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[arg]);
333
                if (arg != filter_column_num)
334
                {
335
                    const ColumnUInt8 & nullmap_column = nullable_col.getNullMapColumn();
336
                    const UInt8 * col_null_map = nullmap_column.getData().data();
337
                    for (size_t r = row_begin; r < row_end; r++)
338
                    {
339
                        final_null_flags[r] |= col_null_map[r];
340
                    }
341
                }
342
                nested_columns[arg] = &nullable_col.getNestedColumn();
343
            }
344
            else
345
                nested_columns[arg] = columns[arg];
346
        }
347

348
        bool at_least_one = false;
349
        for (size_t i = row_begin; i < row_end; i++)
350
        {
351
            if (!final_null_flags[i])
352
            {
353
                at_least_one = true;
354
                break;
355
            }
356
        }
357

358
        if (at_least_one)
359
        {
360
            this->setFlag(place);
361
            this->nested_function->addBatchSinglePlaceNotNull(
362
                row_begin, row_end, this->nestedPlace(place), nested_columns, final_null_flags.get(), arena, -1);
363
        }
364
    }
365

366
#if USE_EMBEDDED_COMPILER
367

368
    bool isCompilable() const override
369
    {
370
        return canBeNativeType(*this->argument_types.back()) && this->nested_function->isCompilable();
371
    }
372

373
    void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
374
    {
375
        llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
376

377
        size_t arguments_size = arguments.size();
378

379
        ValuesWithType wrapped_arguments;
380
        wrapped_arguments.reserve(arguments_size);
381

382
        std::vector<llvm::Value * > is_null_values;
383

384
        for (size_t i = 0; i < arguments_size; ++i)
385
        {
386
            const auto & argument_value = arguments[i].value;
387
            const auto & argument_type = arguments[i].type;
388

389
            if (is_nullable[i])
390
            {
391
                auto * wrapped_value = b.CreateExtractValue(argument_value, {0});
392
                is_null_values.emplace_back(b.CreateExtractValue(argument_value, {1}));
393
                wrapped_arguments.emplace_back(wrapped_value, removeNullable(argument_type));
394
            }
395
            else
396
            {
397
                wrapped_arguments.emplace_back(argument_value, argument_type);
398
            }
399
        }
400

401
        auto * head = b.GetInsertBlock();
402

403
        auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
404
        auto * join_block_after_null_checks = llvm::BasicBlock::Create(head->getContext(), "join_block_after_null_checks", head->getParent());
405

406
        auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
407
        b.CreateStore(b.getInt1(false), values_have_null_ptr);
408

409
        for (auto * is_null_value : is_null_values)
410
        {
411
            auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
412
            b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
413
        }
414

415
        b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), join_block, join_block_after_null_checks);
416

417
        b.SetInsertPoint(join_block_after_null_checks);
418

419
        const auto & predicate_type = arguments.back().type;
420
        auto * predicate_value = arguments.back().value;
421
        auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);
422

423
        auto * if_true = llvm::BasicBlock::Create(head->getContext(), "if_true", head->getParent());
424
        auto * if_false = llvm::BasicBlock::Create(head->getContext(), "if_false", head->getParent());
425

426
        b.CreateCondBr(is_predicate_true, if_true, if_false);
427

428
        b.SetInsertPoint(if_false);
429
        b.CreateBr(join_block);
430

431
        b.SetInsertPoint(if_true);
432

433
        if constexpr (result_is_nullable)
434
            b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
435

436
        auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size);
437
        this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, wrapped_arguments);
438
        b.CreateBr(join_block);
439

440
        b.SetInsertPoint(join_block);
441
    }
442

443
#endif
444

445
private:
446
    using Base = AggregateFunctionNullBase<
447
        result_is_nullable,
448
        serialize_flag,
449
        AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag>>;
450

451
    static constexpr size_t MAX_ARGS = 8;
452
    size_t number_of_arguments = 0;
453
    std::array<char, MAX_ARGS> is_nullable;    /// Plain array is better than std::vector due to one indirection less.
454
};
455

456

457
/// Chooses the null-aware adapter for this *If function.
///
/// The unary adapter is used when there are at most two arguments and the first
/// one is Nullable; otherwise the variadic adapter is used. The template flags are
/// derived from the result nullability and from properties.returns_default_when_only_null.
///
/// NOTE(review): `nested_func` is not declared in this file; presumably it is the
/// member of AggregateFunctionIf holding the function without the combinator - confirm
/// against the header.
AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(
    const AggregateFunctionPtr & nested_function, const DataTypes & arguments,
    const Array & params, const AggregateFunctionProperties & properties) const
{
    assert(!arguments.empty());

    const auto type_is_nullable = [](const auto & element) { return element->isNullable(); };

    /// Nullability of the last argument (condition) does not affect the nullability of the result (NULL is processed as false).
    /// For other arguments it is as usual (at least one is NULL then the result is NULL if possible).
    const bool return_type_is_nullable = !properties.returns_default_when_only_null
        && getResultType()->canBeInsideNullable()
        && std::any_of(arguments.begin(), arguments.end() - 1, type_is_nullable);

    const bool need_to_serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;

    const bool use_unary_adapter = arguments.size() <= 2 && arguments.front()->isNullable();

    if (use_unary_adapter)
    {
        /// The unary adapter reports the name of the function passed in, not of nested_func.
        const String reported_name = nested_function->getName();

        if (return_type_is_nullable)
            return std::make_shared<AggregateFunctionIfNullUnary<true, true>>(reported_name, nested_func, arguments, params);

        if (need_to_serialize_flag)
            return std::make_shared<AggregateFunctionIfNullUnary<false, true>>(reported_name, nested_func, arguments, params);

        return std::make_shared<AggregateFunctionIfNullUnary<false, false>>(reported_name, nested_func, arguments, params);
    }

    if (return_type_is_nullable)
        return std::make_shared<AggregateFunctionIfNullVariadic<true, true>>(nested_function, arguments, params);

    if (need_to_serialize_flag)
        return std::make_shared<AggregateFunctionIfNullVariadic<false, true>>(nested_function, arguments, params);

    return std::make_shared<AggregateFunctionIfNullVariadic<false, false>>(nested_function, arguments, params);
}
499

500
/// Registers the "If" combinator in the given combinator factory.
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory & factory)
{
    const auto if_combinator = std::make_shared<AggregateFunctionCombinatorIf>();
    factory.registerCombinator(if_combinator);
}
504

505
}
506

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.