// ClickHouse
// 543 lines · 19.7 KB
1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/Helpers.h>
3#include <AggregateFunctions/FactoryHelpers.h>
4#include <Common/FieldVisitorConvertToNumber.h>
5#include <DataTypes/DataTypeDate.h>
6#include <DataTypes/DataTypeDateTime.h>
7#include <DataTypes/DataTypeIPv4andIPv6.h>
8#include <DataTypes/DataTypesNumber.h>
9
10#include <IO/WriteHelpers.h>
11#include <IO/ReadHelpers.h>
12#include <IO/ReadHelpersArena.h>
13
14#include <DataTypes/DataTypeArray.h>
15#include <DataTypes/DataTypeTuple.h>
16#include <DataTypes/DataTypeString.h>
17
18#include <Columns/ColumnArray.h>
19
20#include <Common/SpaceSaving.h>
21#include <Common/assert_cast.h>
22
23#include <AggregateFunctions/IAggregateFunction.h>
24#include <AggregateFunctions/KeyHolderHelpers.h>
25
26
27namespace DB
28{
29
30struct Settings;
31
/// Error codes thrown by this translation unit; they are only declared here.
namespace ErrorCodes
{
    extern const int ARGUMENT_OUT_OF_BOUND;
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
40
41
42namespace
43{
44
45inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
46
47template <typename T>
48struct AggregateFunctionTopKData
49{
50using Set = SpaceSaving<T, HashCRC32<T>>;
51
52Set value;
53};
54
55
56template <typename T, bool is_weighted>
57class AggregateFunctionTopK
58: public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>
59{
60protected:
61using State = AggregateFunctionTopKData<T>;
62UInt64 threshold;
63UInt64 reserved;
64bool include_counts;
65bool is_approx_top_k;
66
67public:
68AggregateFunctionTopK(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
69: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, createResultType(argument_types_, include_counts_))
70, threshold(threshold_), reserved(reserved_), include_counts(include_counts_), is_approx_top_k(is_approx_top_k_)
71{}
72
73AggregateFunctionTopK(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params, const DataTypePtr & result_type_)
74: IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, result_type_)
75, threshold(threshold_), reserved(reserved_), include_counts(include_counts_), is_approx_top_k(is_approx_top_k_)
76{}
77
78String getName() const override
79{
80if (is_approx_top_k)
81return is_weighted ? "approx_top_sum" : "approx_top_k";
82else
83return is_weighted ? "topKWeighted" : "topK";
84}
85
86static DataTypePtr createResultType(const DataTypes & argument_types_, bool include_counts_)
87{
88if (include_counts_)
89{
90DataTypes types
91{
92argument_types_[0],
93std::make_shared<DataTypeUInt64>(),
94std::make_shared<DataTypeUInt64>(),
95};
96
97Strings names
98{
99"item",
100"count",
101"error",
102};
103
104return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(
105std::move(types),
106std::move(names)
107));
108}
109else
110return std::make_shared<DataTypeArray>(argument_types_[0]);
111}
112
113bool allocatesMemoryInArena() const override { return false; }
114
115void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
116{
117auto & set = this->data(place).value;
118if (set.capacity() != reserved)
119set.resize(reserved);
120
121if constexpr (is_weighted)
122set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], columns[1]->getUInt(row_num));
123else
124set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
125}
126
127void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
128{
129auto & set = this->data(place).value;
130if (set.capacity() != reserved)
131set.resize(reserved);
132set.merge(this->data(rhs).value);
133}
134
135void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
136{
137this->data(place).value.write(buf);
138}
139
140void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
141{
142auto & set = this->data(place).value;
143set.resize(reserved);
144set.read(buf);
145}
146
147void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
148{
149ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
150ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
151
152const typename State::Set & set = this->data(place).value;
153auto result_vec = set.topK(threshold);
154size_t size = result_vec.size();
155
156offsets_to.push_back(offsets_to.back() + size);
157
158IColumn & data_to = arr_to.getData();
159
160if (include_counts)
161{
162auto & column_tuple = assert_cast<ColumnTuple &>(data_to);
163
164auto & column_key = assert_cast<ColumnVector<T> &>(column_tuple.getColumn(0)).getData();
165auto & column_count = assert_cast<ColumnVector<UInt64> &>(column_tuple.getColumn(1)).getData();
166auto & column_error = assert_cast<ColumnVector<UInt64> &>(column_tuple.getColumn(2)).getData();
167size_t old_size = column_key.size();
168column_key.resize(old_size + size);
169column_count.resize(old_size + size);
170column_error.resize(old_size + size);
171
172size_t i = 0;
173for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i)
174{
175column_key[old_size + i] = it->key;
176column_count[old_size + i] = it->count;
177column_error[old_size + i] = it->error;
178}
179
180} else
181{
182
183auto & column_key = assert_cast<ColumnVector<T> &>(data_to).getData();
184size_t old_size = column_key.size();
185column_key.resize(old_size + size);
186size_t i = 0;
187for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i)
188{
189column_key[old_size + i] = it->key;
190}
191}
192}
193};
194
195
196/// Generic implementation, it uses serialized representation as object descriptor.
197struct AggregateFunctionTopKGenericData
198{
199using Set = SpaceSaving<StringRef, StringRefHash>;
200
201Set value;
202};
203
204/** Template parameter with true value should be used for columns that store their elements in memory continuously.
205* For such columns topK() can be implemented more efficiently (especially for small numeric arrays).
206*/
207template <bool is_plain_column, bool is_weighted>
208class AggregateFunctionTopKGeneric
209: public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>
210{
211private:
212using State = AggregateFunctionTopKGenericData;
213
214UInt64 threshold;
215UInt64 reserved;
216bool include_counts;
217bool is_approx_top_k;
218
219public:
220AggregateFunctionTopKGeneric(
221UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
222: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params, createResultType(argument_types_, include_counts_))
223, threshold(threshold_), reserved(reserved_), include_counts(include_counts_), is_approx_top_k(is_approx_top_k_) {}
224
225String getName() const override
226{
227if (is_approx_top_k)
228return is_weighted ? "approx_top_sum" : "approx_top_k";
229else
230return is_weighted ? "topKWeighted" : "topK";
231}
232
233static DataTypePtr createResultType(const DataTypes & argument_types_, bool include_counts_)
234{
235if (include_counts_)
236{
237DataTypes types
238{
239argument_types_[0],
240std::make_shared<DataTypeUInt64>(),
241std::make_shared<DataTypeUInt64>(),
242};
243
244Strings names
245{
246"item",
247"count",
248"error",
249};
250
251return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(
252std::move(types),
253std::move(names)
254));
255
256} else
257{
258return std::make_shared<DataTypeArray>(argument_types_[0]);
259}
260}
261
262bool allocatesMemoryInArena() const override
263{
264return true;
265}
266
267void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
268{
269this->data(place).value.write(buf);
270}
271
272void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
273{
274auto & set = this->data(place).value;
275set.clear();
276
277// Specialized here because there's no deserialiser for StringRef
278size_t size = 0;
279readVarUInt(size, buf);
280if (unlikely(size > TOP_K_MAX_SIZE))
281throw Exception(
282ErrorCodes::ARGUMENT_OUT_OF_BOUND,
283"Too large size ({}) for aggregate function '{}' state (maximum is {})",
284size,
285getName(),
286TOP_K_MAX_SIZE);
287set.resize(size);
288for (size_t i = 0; i < size; ++i)
289{
290auto ref = readStringBinaryInto(*arena, buf);
291UInt64 count;
292UInt64 error;
293readVarUInt(count, buf);
294readVarUInt(error, buf);
295set.insert(ref, count, error);
296arena->rollback(ref.size);
297}
298
299set.readAlphaMap(buf);
300}
301
302void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
303{
304auto & set = this->data(place).value;
305if (set.capacity() != reserved)
306set.resize(reserved);
307
308if constexpr (is_plain_column)
309{
310if constexpr (is_weighted)
311set.insert(columns[0]->getDataAt(row_num), columns[1]->getUInt(row_num));
312else
313set.insert(columns[0]->getDataAt(row_num));
314}
315else
316{
317const char * begin = nullptr;
318StringRef str_serialized = columns[0]->serializeValueIntoArena(row_num, *arena, begin);
319if constexpr (is_weighted)
320set.insert(str_serialized, columns[1]->getUInt(row_num));
321else
322set.insert(str_serialized);
323arena->rollback(str_serialized.size);
324}
325}
326
327void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
328{
329if (!this->data(rhs).value.size())
330return;
331
332auto & set = this->data(place).value;
333if (set.capacity() != reserved)
334set.resize(reserved);
335set.merge(this->data(rhs).value);
336}
337
338void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
339{
340ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
341ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
342
343const typename State::Set & set = this->data(place).value;
344auto result_vec = set.topK(threshold);
345size_t size = result_vec.size();
346offsets_to.push_back(offsets_to.back() + size);
347
348IColumn & data_to = arr_to.getData();
349
350if (include_counts)
351{
352auto & column_tuple = assert_cast<ColumnTuple &>(data_to);
353IColumn & column_key = column_tuple.getColumn(0);
354IColumn & column_count = column_tuple.getColumn(1);
355IColumn & column_error = column_tuple.getColumn(2);
356for (auto &elem : result_vec)
357{
358column_count.insert(elem.count);
359column_error.insert(elem.error);
360deserializeAndInsert<is_plain_column>(elem.key, column_key);
361}
362} else
363{
364for (auto & elem : result_vec)
365{
366deserializeAndInsert<is_plain_column>(elem.key, data_to);
367}
368}
369}
370};
371
372
373/// Substitute return type for Date and DateTime
374template <bool is_weighted>
375class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
376{
377public:
378using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK;
379
380AggregateFunctionTopKDate(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
381: AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>(
382threshold_,
383reserved_,
384include_counts_,
385is_approx_top_k_,
386argument_types_,
387params)
388{}
389};
390
391template <bool is_weighted>
392class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>
393{
394public:
395using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK;
396
397AggregateFunctionTopKDateTime(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
398: AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>(
399threshold_,
400reserved_,
401include_counts_,
402is_approx_top_k_,
403argument_types_,
404params)
405{}
406};
407
408template <bool is_weighted>
409class AggregateFunctionTopKIPv4 : public AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>
410{
411public:
412using AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>::AggregateFunctionTopK;
413
414AggregateFunctionTopKIPv4(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
415: AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>(
416threshold_,
417reserved_,
418include_counts_,
419is_approx_top_k_,
420argument_types_,
421params)
422{}
423};
424
425
426template <bool is_weighted>
427IAggregateFunction * createWithExtraTypes(const DataTypes & argument_types, UInt64 threshold, UInt64 reserved, bool include_counts, bool is_approx_top_k, const Array & params)
428{
429if (argument_types.empty())
430throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Got empty arguments list");
431
432WhichDataType which(argument_types[0]);
433if (which.idx == TypeIndex::Date)
434return new AggregateFunctionTopKDate<is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
435if (which.idx == TypeIndex::DateTime)
436return new AggregateFunctionTopKDateTime<is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
437if (which.idx == TypeIndex::IPv4)
438return new AggregateFunctionTopKIPv4<is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
439
440/// Check that we can use plain version of AggregateFunctionTopKGeneric
441if (argument_types[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
442return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
443else
444return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
445}
446
447
448template <bool is_weighted, bool is_approx_top_k>
449AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
450{
451if (!is_weighted)
452{
453assertUnary(name, argument_types);
454}
455else
456{
457assertBinary(name, argument_types);
458if (!isInteger(argument_types[1]))
459throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The second argument for aggregate function 'topKWeighted' must have integer type");
460}
461
462UInt64 threshold = 10; /// default values
463UInt64 load_factor = 3;
464bool include_counts = is_approx_top_k;
465UInt64 reserved = threshold * load_factor;
466
467if (!params.empty())
468{
469if (params.size() > 3)
470throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
471"Aggregate function '{}' requires three parameters or less", name);
472
473threshold = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
474
475if (params.size() >= 2)
476{
477if (is_approx_top_k)
478{
479reserved = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
480
481if (reserved < 1)
482throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND,
483"Too small parameter 'reserved' for aggregate function '{}' (got {}, minimum is 1)", name, reserved);
484} else
485{
486load_factor = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
487
488if (load_factor < 1)
489throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND,
490"Too small parameter 'load_factor' for aggregate function '{}' (got {}, minimum is 1)", name, load_factor);
491}
492}
493
494if (params.size() == 3)
495{
496String option = params.at(2).safeGet<String>();
497
498if (option == "counts")
499include_counts = true;
500else
501throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} doesn't support a parameter: {}", name, option);
502
503}
504
505if (!is_approx_top_k)
506{
507reserved = threshold * load_factor;
508}
509
510if (reserved > DB::TOP_K_MAX_SIZE || load_factor > DB::TOP_K_MAX_SIZE || threshold > DB::TOP_K_MAX_SIZE)
511throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND,
512"Too large parameter(s) for aggregate function '{}' (maximum is {})", name, toString(TOP_K_MAX_SIZE));
513
514if (threshold == 0 || reserved == 0)
515throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Parameter 0 is illegal for aggregate function '{}'", name);
516}
517
518AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(
519*argument_types[0], threshold, reserved, include_counts, is_approx_top_k, argument_types, params));
520
521if (!res)
522res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types, threshold, reserved, include_counts, is_approx_top_k, params));
523
524if (!res)
525throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
526"Illegal type {} of argument for aggregate function '{}'", argument_types[0]->getName(), name);
527return res;
528}
529
530}
531
532void registerAggregateFunctionTopK(AggregateFunctionFactory & factory)
533{
534AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true };
535
536factory.registerFunction("topK", { createAggregateFunctionTopK<false, false>, properties });
537factory.registerFunction("topKWeighted", { createAggregateFunctionTopK<true, false>, properties });
538factory.registerFunction("approx_top_k", { createAggregateFunctionTopK<false, true>, properties }, AggregateFunctionFactory::CaseInsensitive);
539factory.registerFunction("approx_top_sum", { createAggregateFunctionTopK<true, true>, properties }, AggregateFunctionFactory::CaseInsensitive);
540factory.registerAlias("approx_top_count", "approx_top_k", AggregateFunctionFactory::CaseInsensitive);
541}
542
543}
544