ClickHouse
505 строк · 19.7 Кб
#include "AggregateFunctionCombinatorFactory.h"
#include "AggregateFunctionIf.h"
#include "AggregateFunctionNull.h"

#include <algorithm>
#include <array>
#include <memory>

5namespace DB6{

/// Error codes thrown from this file; they are only declared (extern) here.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
14
15class AggregateFunctionCombinatorIf final : public IAggregateFunctionCombinator16{
17public:18String getName() const override { return "If"; }19
20DataTypes transformArguments(const DataTypes & arguments) const override21{22if (arguments.empty())23throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,24"Incorrect number of arguments for aggregate function with {} suffix", getName());25
26if (!isUInt8(arguments.back()) && !arguments.back()->onlyNull())27throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of last argument for "28"aggregate function with {} suffix", arguments.back()->getName(), getName());29
30return DataTypes(arguments.begin(), std::prev(arguments.end()));31}32
33AggregateFunctionPtr transformAggregateFunction(34const AggregateFunctionPtr & nested_function,35const AggregateFunctionProperties &,36const DataTypes & arguments,37const Array & params) const override38{39return std::make_shared<AggregateFunctionIf>(nested_function, arguments, params);40}41};42
43
44/** There are two cases: for single argument and variadic.
45* Code for single argument is much more efficient.
46*/
47template <bool result_is_nullable, bool serialize_flag>48class AggregateFunctionIfNullUnary final49: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,50AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>51{
52private:53size_t num_arguments;54bool filter_is_nullable = false;55bool filter_is_only_null = false;56
57/// The name of the nested function, including combinators (i.e. *If)58///59/// getName() from the nested_function cannot be used because in case of *If combinator60/// with Nullable argument nested_function will point to the function without combinator.61/// (I.e. sumIf(Nullable, 1) -> sum()), and distributed query processing will fail.62///63/// And nested_function cannot point to the function with *If since64/// due to optimization in the add() which pass only one column with the result,65/// and so AggregateFunctionIf::add() cannot be called this way66/// (it write to the last argument -- num_arguments-1).67///68/// And to avoid extra level of indirection, the name of function is cached:69///70/// AggregateFunctionIfNullUnary::add -> [ AggregateFunctionIf::add -> ] AggregateFunctionSum::add71String name;72
73using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,74AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>;75
76inline bool singleFilter(const IColumn ** columns, size_t row_num) const77{78const IColumn * filter_column = columns[num_arguments - 1];79
80if (filter_is_nullable)81{82const ColumnNullable * nullable_column = assert_cast<const ColumnNullable *>(filter_column);83filter_column = nullable_column->getNestedColumnPtr().get();84const UInt8 * filter_null_map = nullable_column->getNullMapData().data();85
86return assert_cast<const ColumnUInt8 &>(*filter_column).getData()[row_num] && !filter_null_map[row_num];87}88
89return assert_cast<const ColumnUInt8 &>(*filter_column).getData()[row_num];90}91
92public:93String getName() const override94{95return name;96}97
98AggregateFunctionIfNullUnary(const String & name_, AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)99: Base(std::move(nested_function_), arguments, params)100, num_arguments(arguments.size())101, name(name_)102{103if (num_arguments == 0)104throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,105"Aggregate function {} require at least one argument", getName());106
107filter_is_nullable = arguments[num_arguments - 1]->isNullable();108filter_is_only_null = arguments[num_arguments - 1]->onlyNull();109}110
111void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override112{113if (filter_is_only_null)114return;115
116const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);117const IColumn * nested_column = &column->getNestedColumn();118if (!column->isNullAt(row_num) && singleFilter(columns, row_num))119{120this->setFlag(place);121this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);122}123}124
125void addBatchSinglePlace(126size_t row_begin,127size_t row_end,128AggregateDataPtr __restrict place,129const IColumn ** columns,130Arena * arena,131ssize_t) const override132{133if (filter_is_only_null)134return;135
136const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);137const UInt8 * null_map = column->getNullMapData().data();138const IColumn * columns_param[] = {&column->getNestedColumn()};139
140const IColumn * filter_column = columns[num_arguments - 1];141
142const UInt8 * filter_values = nullptr;143const UInt8 * filter_null_map = nullptr;144
145if (filter_is_nullable)146{147const ColumnNullable * nullable_column = assert_cast<const ColumnNullable *>(filter_column);148filter_column = nullable_column->getNestedColumnPtr().get();149filter_null_map = nullable_column->getNullMapData().data();150}151
152filter_values = assert_cast<const ColumnUInt8 *>(filter_column)->getData().data();153
154/// Combine the 2 flag arrays so we can call a simplified version (one check vs 2)155/// Note that now the null map will contain 0 if not null and not filtered, or 1 for null or filtered (or both)156
157auto final_nulls = std::make_unique<UInt8[]>(row_end);158
159if (filter_null_map)160for (size_t i = row_begin; i < row_end; ++i)161final_nulls[i] = (!!null_map[i]) | (!filter_values[i]) | (!!filter_null_map[i]);162else163for (size_t i = row_begin; i < row_end; ++i)164final_nulls[i] = (!!null_map[i]) | (!filter_values[i]);165
166if constexpr (result_is_nullable)167{168if (!memoryIsByte(final_nulls.get(), row_begin, row_end, 1))169this->setFlag(place);170else171return; /// No work to do.172}173
174this->nested_function->addBatchSinglePlaceNotNull(175row_begin,176row_end,177this->nestedPlace(place),178columns_param,179final_nulls.get(),180arena,181-1);182}183
184#if USE_EMBEDDED_COMPILER185
186bool isCompilable() const override187{188return canBeNativeType(*this->argument_types.back()) && this->nested_function->isCompilable();189}190
191void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override192{193llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);194
195const auto & nullable_type = arguments[0].type;196const auto & nullable_value = arguments[0].value;197
198auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});199auto * is_null_value = b.CreateExtractValue(nullable_value, {1});200
201const auto & predicate_type = arguments.back().type;202auto * predicate_value = arguments.back().value;203auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);204
205auto * head = b.GetInsertBlock();206
207auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());208auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());209auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());210
211b.CreateCondBr(b.CreateAnd(b.CreateNot(is_null_value), is_predicate_true), if_not_null, if_null);212
213b.SetInsertPoint(if_null);214b.CreateBr(join_block);215
216b.SetInsertPoint(if_not_null);217
218if constexpr (result_is_nullable)219b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);220
221auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size);222this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { ValueWithType(wrapped_value, removeNullable(nullable_type)) });223b.CreateBr(join_block);224
225b.SetInsertPoint(join_block);226}227
228#endif229
230};231
232template <bool result_is_nullable, bool serialize_flag>233class AggregateFunctionIfNullVariadic final : public AggregateFunctionNullBase<234result_is_nullable,235serialize_flag,236AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag>>237{
238private:239bool filter_is_only_null = false;240
241public:242
243String getName() const override244{245return Base::getName();246}247
248AggregateFunctionIfNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)249: Base(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size())250{251if (number_of_arguments == 1)252throw Exception(ErrorCodes::LOGICAL_ERROR, "Single argument is passed to AggregateFunctionIfNullVariadic");253
254if (number_of_arguments > MAX_ARGS)255throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,256"Maximum number of arguments for aggregate function with Nullable types is {}", toString(MAX_ARGS));257
258for (size_t i = 0; i < number_of_arguments; ++i)259is_nullable[i] = arguments[i]->isNullable();260
261filter_is_only_null = arguments.back()->onlyNull();262}263
264static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments)265{266return assert_cast<const ColumnUInt8 &>(*columns[num_arguments - 1]).getData()[row_num];267}268
269void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override270{271/// This container stores the columns we really pass to the nested function.272const IColumn * nested_columns[number_of_arguments];273
274for (size_t i = 0; i < number_of_arguments; ++i)275{276if (is_nullable[i])277{278const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[i]);279if (nullable_col.isNullAt(row_num))280{281/// If at least one column has a null value in the current row,282/// we don't process this row.283return;284}285nested_columns[i] = &nullable_col.getNestedColumn();286}287else288nested_columns[i] = columns[i];289}290
291if (singleFilter(nested_columns, row_num, number_of_arguments))292{293this->setFlag(place);294this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);295}296}297
298void addBatchSinglePlace(299size_t row_begin, size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, Arena * arena, ssize_t) const final300{301if (filter_is_only_null)302return;303
304std::unique_ptr<UInt8[]> final_null_flags = std::make_unique<UInt8[]>(row_end);305const size_t filter_column_num = number_of_arguments - 1;306
307if (is_nullable[filter_column_num])308{309const ColumnNullable * nullable_column = assert_cast<const ColumnNullable *>(columns[filter_column_num]);310const IColumn & filter_column = nullable_column->getNestedColumn();311const UInt8 * filter_null_map = nullable_column->getNullMapColumn().getData().data();312const UInt8 * filter_values = assert_cast<const ColumnUInt8 &>(filter_column).getData().data();313
314for (size_t i = row_begin; i < row_end; i++)315{316final_null_flags[i] = filter_null_map[i] || !filter_values[i];317}318}319else320{321const IColumn * filter_column = columns[filter_column_num];322const UInt8 * filter_values = assert_cast<const ColumnUInt8 *>(filter_column)->getData().data();323for (size_t i = row_begin; i < row_end; i++)324final_null_flags[i] = !filter_values[i];325}326
327const IColumn * nested_columns[number_of_arguments];328for (size_t arg = 0; arg < number_of_arguments; arg++)329{330if (is_nullable[arg])331{332const ColumnNullable & nullable_col = assert_cast<const ColumnNullable &>(*columns[arg]);333if (arg != filter_column_num)334{335const ColumnUInt8 & nullmap_column = nullable_col.getNullMapColumn();336const UInt8 * col_null_map = nullmap_column.getData().data();337for (size_t r = row_begin; r < row_end; r++)338{339final_null_flags[r] |= col_null_map[r];340}341}342nested_columns[arg] = &nullable_col.getNestedColumn();343}344else345nested_columns[arg] = columns[arg];346}347
348bool at_least_one = false;349for (size_t i = row_begin; i < row_end; i++)350{351if (!final_null_flags[i])352{353at_least_one = true;354break;355}356}357
358if (at_least_one)359{360this->setFlag(place);361this->nested_function->addBatchSinglePlaceNotNull(362row_begin, row_end, this->nestedPlace(place), nested_columns, final_null_flags.get(), arena, -1);363}364}365
366#if USE_EMBEDDED_COMPILER367
368bool isCompilable() const override369{370return canBeNativeType(*this->argument_types.back()) && this->nested_function->isCompilable();371}372
373void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override374{375llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);376
377size_t arguments_size = arguments.size();378
379ValuesWithType wrapped_arguments;380wrapped_arguments.reserve(arguments_size);381
382std::vector<llvm::Value * > is_null_values;383
384for (size_t i = 0; i < arguments_size; ++i)385{386const auto & argument_value = arguments[i].value;387const auto & argument_type = arguments[i].type;388
389if (is_nullable[i])390{391auto * wrapped_value = b.CreateExtractValue(argument_value, {0});392is_null_values.emplace_back(b.CreateExtractValue(argument_value, {1}));393wrapped_arguments.emplace_back(wrapped_value, removeNullable(argument_type));394}395else396{397wrapped_arguments.emplace_back(argument_value, argument_type);398}399}400
401auto * head = b.GetInsertBlock();402
403auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());404auto * join_block_after_null_checks = llvm::BasicBlock::Create(head->getContext(), "join_block_after_null_checks", head->getParent());405
406auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());407b.CreateStore(b.getInt1(false), values_have_null_ptr);408
409for (auto * is_null_value : is_null_values)410{411auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);412b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);413}414
415b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), join_block, join_block_after_null_checks);416
417b.SetInsertPoint(join_block_after_null_checks);418
419const auto & predicate_type = arguments.back().type;420auto * predicate_value = arguments.back().value;421auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);422
423auto * if_true = llvm::BasicBlock::Create(head->getContext(), "if_true", head->getParent());424auto * if_false = llvm::BasicBlock::Create(head->getContext(), "if_false", head->getParent());425
426b.CreateCondBr(is_predicate_true, if_true, if_false);427
428b.SetInsertPoint(if_false);429b.CreateBr(join_block);430
431b.SetInsertPoint(if_true);432
433if constexpr (result_is_nullable)434b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);435
436auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size);437this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, wrapped_arguments);438b.CreateBr(join_block);439
440b.SetInsertPoint(join_block);441}442
443#endif444
445private:446using Base = AggregateFunctionNullBase<447result_is_nullable,448serialize_flag,449AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag>>;450
451static constexpr size_t MAX_ARGS = 8;452size_t number_of_arguments = 0;453std::array<char, MAX_ARGS> is_nullable; /// Plain array is better than std::vector due to one indirection less.454};455
456
457AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(458const AggregateFunctionPtr & nested_function, const DataTypes & arguments,459const Array & params, const AggregateFunctionProperties & properties) const460{
461assert(!arguments.empty());462
463/// Nullability of the last argument (condition) does not affect the nullability of the result (NULL is processed as false).464/// For other arguments it is as usual (at least one is NULL then the result is NULL if possible).465bool return_type_is_nullable = !properties.returns_default_when_only_null && getResultType()->canBeInsideNullable()466&& std::any_of(arguments.begin(), arguments.end() - 1, [](const auto & element) { return element->isNullable(); });467
468bool need_to_serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;469
470if (arguments.size() <= 2 && arguments.front()->isNullable())471{472if (return_type_is_nullable)473{474return std::make_shared<AggregateFunctionIfNullUnary<true, true>>(nested_function->getName(), nested_func, arguments, params);475}476else477{478if (need_to_serialize_flag)479return std::make_shared<AggregateFunctionIfNullUnary<false, true>>(nested_function->getName(), nested_func, arguments, params);480else481return std::make_shared<AggregateFunctionIfNullUnary<false, false>>(nested_function->getName(), nested_func, arguments, params);482}483}484else485{486if (return_type_is_nullable)487{488return std::make_shared<AggregateFunctionIfNullVariadic<true, true>>(nested_function, arguments, params);489}490else491{492if (need_to_serialize_flag)493return std::make_shared<AggregateFunctionIfNullVariadic<false, true>>(nested_function, arguments, params);494else495return std::make_shared<AggregateFunctionIfNullVariadic<false, false>>(nested_function, arguments, params);496}497}498}
499
500void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory & factory)501{
502factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorIf>());503}
504
505}
506