ClickHouse
416 lines · 13.9 KB
1#include <AggregateFunctions/IAggregateFunction.h>2#include <AggregateFunctions/AggregateFunctionFactory.h>3#include <AggregateFunctions/Helpers.h>4#include <AggregateFunctions/FactoryHelpers.h>5
6#include <base/sort.h>7#include <algorithm>8#include <type_traits>9#include <utility>10
11#include <Common/RadixSort.h>12#include <Common/Exception.h>13#include <Common/ArenaAllocator.h>14#include <Common/assert_cast.h>15
16#include <IO/ReadHelpers.h>17#include <IO/WriteHelpers.h>18#include <IO/ReadBufferFromString.h>19#include <IO/WriteBufferFromString.h>20#include <IO/Operators.h>21
22#include <DataTypes/IDataType.h>23#include <DataTypes/DataTypeDate.h>24#include <DataTypes/DataTypeDateTime.h>25#include <DataTypes/DataTypeArray.h>26#include <DataTypes/DataTypeString.h>27#include <DataTypes/DataTypesNumber.h>28#include <Columns/ColumnArray.h>29#include <Columns/ColumnString.h>30#include <Columns/ColumnVector.h>31
32#include <Columns/IColumn.h>33#include <Columns/ColumnConst.h>34
35namespace DB36{
37
/// Forward declaration only: the factory callback in this file takes `const Settings *` but never uses it.
struct Settings;

/// Error codes thrown from this translation unit; they are only declared here (`extern`), defined elsewhere.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int TOO_LARGE_ARRAY_SIZE;
}
46
47namespace
48{
49
/// How GroupArraySortedData keeps its candidate values:
///  - heap: a max-heap of at most `max_elements` values, maintained on every insertion;
///  - sort: values are appended and periodically trimmed with nth_element.
enum class GroupArraySortedStrategy
{
    heap = 0,
    sort = 1,
};
/// Limits strictly above this value make groupArraySorted use the sort strategy instead of the heap strategy.
constexpr size_t group_array_sorted_sort_strategy_max_elements_threshold = 1'000'000;
58template <typename T, GroupArraySortedStrategy strategy>59struct GroupArraySortedData60{
61using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;62using Array = PODArray<T, 32, Allocator>;63
64static constexpr size_t partial_sort_max_elements_factor = 2;65
66static constexpr bool is_value_generic_field = std::is_same_v<T, Field>;67
68Array values;69
70static bool compare(const T & lhs, const T & rhs)71{72if constexpr (is_value_generic_field)73{74return lhs < rhs;75}76else77{78return CompareHelper<T>::less(lhs, rhs, -1);79}80}81
82struct Comparator83{84bool operator()(const T & lhs, const T & rhs)85{86return compare(lhs, rhs);87}88};89
90ALWAYS_INLINE void heapReplaceTop()91{92size_t size = values.size();93if (size < 2)94return;95
96size_t child_index = 1;97
98if (values.size() > 2 && compare(values[1], values[2]))99++child_index;100
101/// Check if we are in order102if (compare(values[child_index], values[0]))103return;104
105size_t current_index = 0;106auto current = values[current_index];107
108do109{110/// We are not in heap-order, swap the parent with it's largest child.111values[current_index] = values[child_index];112current_index = child_index;113
114// Recompute the child based off of the updated parent115child_index = 2 * child_index + 1;116
117if (child_index >= size)118break;119
120if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1]))121{122/// Right child exists and is greater than left child.123++child_index;124}125
126/// Check if we are in order.127} while (!compare(values[child_index], current));128
129values[current_index] = current;130}131
132ALWAYS_INLINE void sortAndLimit(size_t max_elements, Arena * arena)133{134if constexpr (is_value_generic_field)135{136::sort(values.begin(), values.end(), Comparator());137}138else139{140bool try_sort = trySort(values.begin(), values.end(), Comparator());141if (!try_sort)142RadixSort<RadixSortNumTraits<T>>::executeLSD(values.data(), values.size());143}144
145if (values.size() > max_elements)146values.resize(max_elements, arena);147}148
149ALWAYS_INLINE void partialSortAndLimitIfNeeded(size_t max_elements, Arena * arena)150{151if (values.size() < max_elements * partial_sort_max_elements_factor)152return;153
154::nth_element(values.begin(), values.begin() + max_elements, values.end(), Comparator());155values.resize(max_elements, arena);156}157
158ALWAYS_INLINE void addElement(T && element, size_t max_elements, Arena * arena)159{160if constexpr (strategy == GroupArraySortedStrategy::heap)161{162if (values.size() >= max_elements)163{164/// Element is greater or equal than current max element, it cannot be in k min elements165if (!compare(element, values[0]))166return;167
168values[0] = std::move(element);169heapReplaceTop();170return;171}172
173values.push_back(std::move(element), arena);174std::push_heap(values.begin(), values.end(), Comparator());175}176else177{178values.push_back(std::move(element), arena);179partialSortAndLimitIfNeeded(max_elements, arena);180}181}182
183ALWAYS_INLINE void insertResultInto(IColumn & to, size_t max_elements, Arena * arena)184{185auto & result_array = assert_cast<ColumnArray &>(to);186auto & result_array_offsets = result_array.getOffsets();187
188sortAndLimit(max_elements, arena);189
190result_array_offsets.push_back(result_array_offsets.back() + values.size());191
192if (values.empty())193return;194
195if constexpr (is_value_generic_field)196{197auto & result_array_data = result_array.getData();198for (auto & value : values)199result_array_data.insert(value);200}201else202{203auto & result_array_data = assert_cast<ColumnVector<T> &>(result_array.getData()).getData();204
205size_t result_array_data_insert_begin = result_array_data.size();206result_array_data.resize(result_array_data_insert_begin + values.size());207
208for (size_t i = 0; i < values.size(); ++i)209result_array_data[result_array_data_insert_begin + i] = values[i];210}211}212};213
214template <typename T>215using GroupArraySortedDataHeap = GroupArraySortedData<T, GroupArraySortedStrategy::heap>;216
217template <typename T>218using GroupArraySortedDataSort = GroupArraySortedData<T, GroupArraySortedStrategy::sort>;219
220constexpr UInt64 aggregate_function_group_array_sorted_max_element_size = 0xFFFFFF;221
222template <typename Data, typename T>223class GroupArraySorted final224: public IAggregateFunctionDataHelper<Data, GroupArraySorted<Data, T>>225{
226public:227explicit GroupArraySorted(228const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elements_)229: IAggregateFunctionDataHelper<Data, GroupArraySorted<Data, T>>(230{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))231, max_elements(max_elements_)232, serialization(data_type_->getDefaultSerialization())233{234if (max_elements > aggregate_function_group_array_sorted_max_element_size)235throw Exception(ErrorCodes::BAD_ARGUMENTS,236"Too large limit parameter for groupArraySorted aggregate function, it should not exceed {}",237aggregate_function_group_array_sorted_max_element_size);238}239
240String getName() const override { return "groupArraySorted"; }241
242void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override243{244if constexpr (std::is_same_v<T, Field>)245{246auto row_value = (*columns[0])[row_num];247this->data(place).addElement(std::move(row_value), max_elements, arena);248}249else250{251auto row_value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];252this->data(place).addElement(std::move(row_value), max_elements, arena);253}254}255
256void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override257{258auto & rhs_values = this->data(rhs).values;259for (auto rhs_element : rhs_values)260this->data(place).addElement(std::move(rhs_element), max_elements, arena);261}262
263void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override264{265auto & values = this->data(place).values;266size_t size = values.size();267writeVarUInt(size, buf);268
269if constexpr (std::is_same_v<T, Field>)270{271for (const Field & element : values)272{273if (element.isNull())274{275writeBinary(false, buf);276}277else278{279writeBinary(true, buf);280serialization->serializeBinary(element, buf, {});281}282}283}284else285{286if constexpr (std::endian::native == std::endian::little)287{288buf.write(reinterpret_cast<const char *>(values.data()), size * sizeof(values[0]));289}290else291{292for (const auto & element : values)293writeBinaryLittleEndian(element, buf);294}295}296}297
298void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override299{300size_t size = 0;301readVarUInt(size, buf);302
303if (unlikely(size > max_elements))304throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elements);305
306auto & values = this->data(place).values;307values.resize_exact(size, arena);308
309if constexpr (std::is_same_v<T, Field>)310{311for (Field & element : values)312{313/// We must initialize the Field type since some internal functions (like operator=) use them314new (&element) Field;315bool has_value = false;316readBinary(has_value, buf);317if (has_value)318serialization->deserializeBinary(element, buf, {});319}320}321else322{323if constexpr (std::endian::native == std::endian::little)324{325buf.readStrict(reinterpret_cast<char *>(values.data()), size * sizeof(values[0]));326}327else328{329for (auto & element : values)330readBinaryLittleEndian(element, buf);331}332}333}334
335void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override336{337this->data(place).insertResultInto(to, max_elements, arena);338}339
340bool allocatesMemoryInArena() const override { return true; }341
342private:343UInt64 max_elements;344SerializationPtr serialization;345};346
347template <typename T>348using GroupArraySortedHeap = GroupArraySorted<GroupArraySortedDataHeap<T>, T>;349
350template <typename T>351using GroupArraySortedSort = GroupArraySorted<GroupArraySortedDataSort<T>, T>;352
353template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>354AggregateFunctionPtr createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args)355{
356WhichDataType which(argument_type);357
358if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16>>(std::forward<TArgs>(args)...);359if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32>>(std::forward<TArgs>(args)...);360if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4>>(std::forward<TArgs>(args)...);361
362return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, TArgs...>(argument_type, std::forward<TArgs>(args)...));363}
364
365template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>366inline AggregateFunctionPtr createAggregateFunctionGroupArraySortedImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)367{
368if (auto res = createWithNumericOrTimeType<AggregateFunctionTemplate>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))369return AggregateFunctionPtr(res);370
371return std::make_shared<AggregateFunctionTemplate<Field>>(argument_type, parameters, std::forward<TArgs>(args)...);372}
373
374AggregateFunctionPtr createAggregateFunctionGroupArray(375const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)376{
377assertUnary(name, argument_types);378
379UInt64 max_elems = std::numeric_limits<UInt64>::max();380
381if (parameters.empty())382{383throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should have limit argument", name);384}385else if (parameters.size() == 1)386{387auto type = parameters[0].getType();388if (type != Field::Types::Int64 && type != Field::Types::UInt64)389throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);390
391if ((type == Field::Types::Int64 && parameters[0].get<Int64>() < 0) ||392(type == Field::Types::UInt64 && parameters[0].get<UInt64>() == 0))393throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);394
395max_elems = parameters[0].get<UInt64>();396}397else398throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,399"Function {} does not support this number of arguments", name);400
401if (max_elems > group_array_sorted_sort_strategy_max_elements_threshold)402return createAggregateFunctionGroupArraySortedImpl<GroupArraySortedSort>(argument_types[0], parameters, max_elems);403
404return createAggregateFunctionGroupArraySortedImpl<GroupArraySortedHeap>(argument_types[0], parameters, max_elems);405}
406
407}
408
409void registerAggregateFunctionGroupArraySorted(AggregateFunctionFactory & factory)410{
411AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = false };412
413factory.registerFunction("groupArraySorted", { createAggregateFunctionGroupArray, properties });414}
415
416}
417