ClickHouse

Форк
0
/
AggregateFunctionGroupArraySorted.cpp 
416 строк · 13.9 Кб
1
#include <AggregateFunctions/IAggregateFunction.h>
2
#include <AggregateFunctions/AggregateFunctionFactory.h>
3
#include <AggregateFunctions/Helpers.h>
4
#include <AggregateFunctions/FactoryHelpers.h>
5

6
#include <base/sort.h>
7
#include <algorithm>
8
#include <type_traits>
9
#include <utility>
10

11
#include <Common/RadixSort.h>
12
#include <Common/Exception.h>
13
#include <Common/ArenaAllocator.h>
14
#include <Common/assert_cast.h>
15

16
#include <IO/ReadHelpers.h>
17
#include <IO/WriteHelpers.h>
18
#include <IO/ReadBufferFromString.h>
19
#include <IO/WriteBufferFromString.h>
20
#include <IO/Operators.h>
21

22
#include <DataTypes/IDataType.h>
23
#include <DataTypes/DataTypeDate.h>
24
#include <DataTypes/DataTypeDateTime.h>
25
#include <DataTypes/DataTypeArray.h>
26
#include <DataTypes/DataTypeString.h>
27
#include <DataTypes/DataTypesNumber.h>
28
#include <Columns/ColumnArray.h>
29
#include <Columns/ColumnString.h>
30
#include <Columns/ColumnVector.h>
31

32
#include <Columns/IColumn.h>
33
#include <Columns/ColumnConst.h>
34

35
namespace DB
36
{
37

38
/// Forward declaration only: the factory below takes a `const Settings *` it never dereferences.
struct Settings;

/// Error codes thrown from this file; their definitions live elsewhere in the project.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int TOO_LARGE_ARRAY_SIZE;
}
46

47
namespace
48
{
49

50
/// How GroupArraySortedData keeps track of the candidate (smallest) values.
enum class GroupArraySortedStrategy
{
    /// Bounded max-heap holding at most N values; chosen for limits up to the threshold below.
    heap,
    /// Unordered buffer that is periodically cut down with nth_element; chosen for limits above the threshold.
    sort
};

/// Limits N strictly greater than this value use the `sort` strategy, all others use `heap`.
constexpr size_t group_array_sorted_sort_strategy_max_elements_threshold = 1'000'000;
57

58
template <typename T, GroupArraySortedStrategy strategy>
59
struct GroupArraySortedData
60
{
61
    using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
62
    using Array = PODArray<T, 32, Allocator>;
63

64
    static constexpr size_t partial_sort_max_elements_factor = 2;
65

66
    static constexpr bool is_value_generic_field = std::is_same_v<T, Field>;
67

68
    Array values;
69

70
    static bool compare(const T & lhs, const T & rhs)
71
    {
72
        if constexpr (is_value_generic_field)
73
        {
74
            return lhs < rhs;
75
        }
76
        else
77
        {
78
            return CompareHelper<T>::less(lhs, rhs, -1);
79
        }
80
    }
81

82
    struct Comparator
83
    {
84
        bool operator()(const T & lhs, const T & rhs)
85
        {
86
            return compare(lhs, rhs);
87
        }
88
    };
89

90
    ALWAYS_INLINE void heapReplaceTop()
91
    {
92
        size_t size = values.size();
93
        if (size < 2)
94
            return;
95

96
        size_t child_index = 1;
97

98
        if (values.size() > 2 && compare(values[1], values[2]))
99
            ++child_index;
100

101
        /// Check if we are in order
102
        if (compare(values[child_index], values[0]))
103
            return;
104

105
        size_t current_index = 0;
106
        auto current = values[current_index];
107

108
        do
109
        {
110
            /// We are not in heap-order, swap the parent with it's largest child.
111
            values[current_index] = values[child_index];
112
            current_index = child_index;
113

114
            // Recompute the child based off of the updated parent
115
            child_index = 2 * child_index + 1;
116

117
            if (child_index >= size)
118
                break;
119

120
            if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1]))
121
            {
122
                /// Right child exists and is greater than left child.
123
                ++child_index;
124
            }
125

126
            /// Check if we are in order.
127
        } while (!compare(values[child_index], current));
128

129
        values[current_index] = current;
130
    }
131

132
    ALWAYS_INLINE void sortAndLimit(size_t max_elements, Arena * arena)
133
    {
134
        if constexpr (is_value_generic_field)
135
        {
136
            ::sort(values.begin(), values.end(), Comparator());
137
        }
138
        else
139
        {
140
            bool try_sort = trySort(values.begin(), values.end(), Comparator());
141
            if (!try_sort)
142
                RadixSort<RadixSortNumTraits<T>>::executeLSD(values.data(), values.size());
143
        }
144

145
        if (values.size() > max_elements)
146
            values.resize(max_elements, arena);
147
    }
148

149
    ALWAYS_INLINE void partialSortAndLimitIfNeeded(size_t max_elements, Arena * arena)
150
    {
151
        if (values.size() < max_elements * partial_sort_max_elements_factor)
152
            return;
153

154
        ::nth_element(values.begin(), values.begin() + max_elements, values.end(), Comparator());
155
        values.resize(max_elements, arena);
156
    }
157

158
    ALWAYS_INLINE void addElement(T && element, size_t max_elements, Arena * arena)
159
    {
160
        if constexpr (strategy == GroupArraySortedStrategy::heap)
161
        {
162
            if (values.size() >= max_elements)
163
            {
164
                /// Element is greater or equal than current max element, it cannot be in k min elements
165
                if (!compare(element, values[0]))
166
                    return;
167

168
                values[0] = std::move(element);
169
                heapReplaceTop();
170
                return;
171
            }
172

173
            values.push_back(std::move(element), arena);
174
            std::push_heap(values.begin(), values.end(), Comparator());
175
        }
176
        else
177
        {
178
            values.push_back(std::move(element), arena);
179
            partialSortAndLimitIfNeeded(max_elements, arena);
180
        }
181
    }
182

183
    ALWAYS_INLINE void insertResultInto(IColumn & to, size_t max_elements, Arena * arena)
184
    {
185
        auto & result_array = assert_cast<ColumnArray &>(to);
186
        auto & result_array_offsets = result_array.getOffsets();
187

188
        sortAndLimit(max_elements, arena);
189

190
        result_array_offsets.push_back(result_array_offsets.back() + values.size());
191

192
        if (values.empty())
193
            return;
194

195
        if constexpr (is_value_generic_field)
196
        {
197
            auto & result_array_data = result_array.getData();
198
            for (auto & value : values)
199
                result_array_data.insert(value);
200
        }
201
        else
202
        {
203
            auto & result_array_data = assert_cast<ColumnVector<T> &>(result_array.getData()).getData();
204

205
            size_t result_array_data_insert_begin = result_array_data.size();
206
            result_array_data.resize(result_array_data_insert_begin + values.size());
207

208
            for (size_t i = 0; i < values.size(); ++i)
209
                result_array_data[result_array_data_insert_begin + i] = values[i];
210
        }
211
    }
212
};
213

214
template <typename T>
215
using GroupArraySortedDataHeap = GroupArraySortedData<T, GroupArraySortedStrategy::heap>;
216

217
template <typename T>
218
using GroupArraySortedDataSort = GroupArraySortedData<T, GroupArraySortedStrategy::sort>;
219

220
constexpr UInt64 aggregate_function_group_array_sorted_max_element_size = 0xFFFFFF;
221

222
template <typename Data, typename T>
223
class GroupArraySorted final
224
    : public IAggregateFunctionDataHelper<Data, GroupArraySorted<Data, T>>
225
{
226
public:
227
    explicit GroupArraySorted(
228
        const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elements_)
229
        : IAggregateFunctionDataHelper<Data, GroupArraySorted<Data, T>>(
230
            {data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
231
        , max_elements(max_elements_)
232
        , serialization(data_type_->getDefaultSerialization())
233
    {
234
        if (max_elements > aggregate_function_group_array_sorted_max_element_size)
235
            throw Exception(ErrorCodes::BAD_ARGUMENTS,
236
                "Too large limit parameter for groupArraySorted aggregate function, it should not exceed {}",
237
                aggregate_function_group_array_sorted_max_element_size);
238
    }
239

240
    String getName() const override { return "groupArraySorted"; }
241

242
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
243
    {
244
        if constexpr (std::is_same_v<T, Field>)
245
        {
246
            auto row_value = (*columns[0])[row_num];
247
            this->data(place).addElement(std::move(row_value), max_elements, arena);
248
        }
249
        else
250
        {
251
            auto row_value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
252
            this->data(place).addElement(std::move(row_value), max_elements, arena);
253
        }
254
    }
255

256
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
257
    {
258
        auto & rhs_values = this->data(rhs).values;
259
        for (auto rhs_element : rhs_values)
260
            this->data(place).addElement(std::move(rhs_element), max_elements, arena);
261
    }
262

263
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
264
    {
265
        auto & values = this->data(place).values;
266
        size_t size = values.size();
267
        writeVarUInt(size, buf);
268

269
        if constexpr (std::is_same_v<T, Field>)
270
        {
271
            for (const Field & element : values)
272
            {
273
                if (element.isNull())
274
                {
275
                    writeBinary(false, buf);
276
                }
277
                else
278
                {
279
                    writeBinary(true, buf);
280
                    serialization->serializeBinary(element, buf, {});
281
                }
282
            }
283
        }
284
        else
285
        {
286
            if constexpr (std::endian::native == std::endian::little)
287
            {
288
                buf.write(reinterpret_cast<const char *>(values.data()), size * sizeof(values[0]));
289
            }
290
            else
291
            {
292
                for (const auto & element : values)
293
                    writeBinaryLittleEndian(element, buf);
294
            }
295
        }
296
    }
297

298
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
299
    {
300
        size_t size = 0;
301
        readVarUInt(size, buf);
302

303
        if (unlikely(size > max_elements))
304
            throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elements);
305

306
        auto & values = this->data(place).values;
307
        values.resize_exact(size, arena);
308

309
        if constexpr (std::is_same_v<T, Field>)
310
        {
311
            for (Field & element : values)
312
            {
313
                /// We must initialize the Field type since some internal functions (like operator=) use them
314
                new (&element) Field;
315
                bool has_value = false;
316
                readBinary(has_value, buf);
317
                if (has_value)
318
                    serialization->deserializeBinary(element, buf, {});
319
            }
320
        }
321
        else
322
        {
323
            if constexpr (std::endian::native == std::endian::little)
324
            {
325
                buf.readStrict(reinterpret_cast<char *>(values.data()), size * sizeof(values[0]));
326
            }
327
            else
328
            {
329
                for (auto & element : values)
330
                    readBinaryLittleEndian(element, buf);
331
            }
332
        }
333
    }
334

335
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
336
    {
337
        this->data(place).insertResultInto(to, max_elements, arena);
338
    }
339

340
    bool allocatesMemoryInArena() const override { return true; }
341

342
private:
343
    UInt64 max_elements;
344
    SerializationPtr serialization;
345
};
346

347
template <typename T>
348
using GroupArraySortedHeap = GroupArraySorted<GroupArraySortedDataHeap<T>, T>;
349

350
template <typename T>
351
using GroupArraySortedSort = GroupArraySorted<GroupArraySortedDataSort<T>, T>;
352

353
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
354
AggregateFunctionPtr createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args)
355
{
356
    WhichDataType which(argument_type);
357

358
    if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16>>(std::forward<TArgs>(args)...);
359
    if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32>>(std::forward<TArgs>(args)...);
360
    if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4>>(std::forward<TArgs>(args)...);
361

362
    return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, TArgs...>(argument_type, std::forward<TArgs>(args)...));
363
}
364

365
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
366
inline AggregateFunctionPtr createAggregateFunctionGroupArraySortedImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)
367
{
368
    if (auto res = createWithNumericOrTimeType<AggregateFunctionTemplate>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
369
        return AggregateFunctionPtr(res);
370

371
    return std::make_shared<AggregateFunctionTemplate<Field>>(argument_type, parameters, std::forward<TArgs>(args)...);
372
}
373

374
/// Factory for groupArraySorted(N)(x).
///
/// Requires exactly one argument and exactly one parameter N, which must be a positive integer
/// (Int64 or UInt64 Field). Limits above group_array_sorted_sort_strategy_max_elements_threshold use the
/// sort strategy, the rest use the heap strategy. The upper bound on N is checked by the GroupArraySorted
/// constructor.
/// @throws Exception(BAD_ARGUMENTS) when N is missing, not an integer, or not positive.
/// @throws Exception(NUMBER_OF_ARGUMENTS_DOESNT_MATCH) when more than one parameter is given.
AggregateFunctionPtr createAggregateFunctionGroupArray(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertUnary(name, argument_types);

    if (parameters.empty())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should have limit argument", name);

    if (parameters.size() != 1)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
            "Function {} does not support this number of arguments", name);

    const auto type = parameters[0].getType();
    if (type != Field::Types::Int64 && type != Field::Types::UInt64)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);

    /// Zero must be rejected for both integer representations: with a limit of 0 the heap strategy would
    /// read values[0] of an empty state on the very first added row.
    if ((type == Field::Types::Int64 && parameters[0].get<Int64>() <= 0) ||
        (type == Field::Types::UInt64 && parameters[0].get<UInt64>() == 0))
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);

    /// Positive at this point, so reading an Int64 Field as UInt64 is value-preserving.
    const UInt64 max_elems = parameters[0].get<UInt64>();

    if (max_elems > group_array_sorted_sort_strategy_max_elements_threshold)
        return createAggregateFunctionGroupArraySortedImpl<GroupArraySortedSort>(argument_types[0], parameters, max_elems);

    return createAggregateFunctionGroupArraySortedImpl<GroupArraySortedHeap>(argument_types[0], parameters, max_elems);
}
406

407
}
408

409
/// Registers groupArraySorted with the factory; the function is flagged as not order-dependent and with
/// returns_default_when_only_null disabled.
void registerAggregateFunctionGroupArraySorted(AggregateFunctionFactory & factory)
{
    factory.registerFunction(
        "groupArraySorted",
        {createAggregateFunctionGroupArray,
         AggregateFunctionProperties{.returns_default_when_only_null = false, .is_order_dependent = false}});
}
415

416
}
417

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.