ClickHouse

Форк
0
/
AggregateFunctionTopK.cpp 
543 строки · 19.7 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/Helpers.h>
3
#include <AggregateFunctions/FactoryHelpers.h>
4
#include <Common/FieldVisitorConvertToNumber.h>
5
#include <DataTypes/DataTypeDate.h>
6
#include <DataTypes/DataTypeDateTime.h>
7
#include <DataTypes/DataTypeIPv4andIPv6.h>
8
#include <DataTypes/DataTypesNumber.h>
9

10
#include <IO/WriteHelpers.h>
11
#include <IO/ReadHelpers.h>
12
#include <IO/ReadHelpersArena.h>
13

14
#include <DataTypes/DataTypeArray.h>
15
#include <DataTypes/DataTypeTuple.h>
16
#include <DataTypes/DataTypeString.h>
17

18
#include <Columns/ColumnArray.h>
19

20
#include <Common/SpaceSaving.h>
21
#include <Common/assert_cast.h>
22

23
#include <AggregateFunctions/IAggregateFunction.h>
24
#include <AggregateFunctions/KeyHolderHelpers.h>
25

26

27
namespace DB
28
{
29

30
struct Settings;
31

32
namespace ErrorCodes
33
{
34
    extern const int ARGUMENT_OUT_OF_BOUND;
35
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
36
    extern const int BAD_ARGUMENTS;
37
    extern const int LOGICAL_ERROR;
38
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
39
}
40

41

42
namespace
43
{
44

45
/// Upper bound (0xFFFFFF = 16777215) used in this file in two places:
/// the factory rejects `threshold`, `load_factor` and `reserved` values above it,
/// and the generic implementation rejects a serialized state that declares more counters than this.
inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
46

47
template <typename T>
48
struct AggregateFunctionTopKData
49
{
50
    using Set = SpaceSaving<T, HashCRC32<T>>;
51

52
    Set value;
53
};
54

55

56
template <typename T, bool is_weighted>
57
class AggregateFunctionTopK
58
    : public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>
59
{
60
protected:
61
    using State = AggregateFunctionTopKData<T>;
62
    UInt64 threshold;
63
    UInt64 reserved;
64
    bool include_counts;
65
    bool is_approx_top_k;
66

67
public:
68
    AggregateFunctionTopK(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
69
        : IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, createResultType(argument_types_, include_counts_))
70
        , threshold(threshold_), reserved(reserved_), include_counts(include_counts_), is_approx_top_k(is_approx_top_k_)
71
    {}
72

73
        AggregateFunctionTopK(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params, const DataTypePtr & result_type_)
74
        : IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>(argument_types_, params, result_type_)
75
        , threshold(threshold_), reserved(reserved_), include_counts(include_counts_), is_approx_top_k(is_approx_top_k_)
76
    {}
77

78
    String getName() const override
79
    {
80
        if (is_approx_top_k)
81
            return  is_weighted ? "approx_top_sum" : "approx_top_k";
82
        else
83
            return  is_weighted ? "topKWeighted" : "topK";
84
    }
85

86
    static DataTypePtr createResultType(const DataTypes & argument_types_, bool include_counts_)
87
    {
88
        if (include_counts_)
89
        {
90
            DataTypes types
91
            {
92
                argument_types_[0],
93
                std::make_shared<DataTypeUInt64>(),
94
                std::make_shared<DataTypeUInt64>(),
95
            };
96

97
            Strings names
98
            {
99
                "item",
100
                "count",
101
                "error",
102
            };
103

104
            return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(
105
                std::move(types),
106
                std::move(names)
107
            ));
108
        }
109
        else
110
            return std::make_shared<DataTypeArray>(argument_types_[0]);
111
    }
112

113
    bool allocatesMemoryInArena() const override { return false; }
114

115
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
116
    {
117
        auto & set = this->data(place).value;
118
        if (set.capacity() != reserved)
119
            set.resize(reserved);
120

121
        if constexpr (is_weighted)
122
            set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], columns[1]->getUInt(row_num));
123
        else
124
            set.insert(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
125
    }
126

127
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
128
    {
129
        auto & set = this->data(place).value;
130
        if (set.capacity() != reserved)
131
            set.resize(reserved);
132
        set.merge(this->data(rhs).value);
133
    }
134

135
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
136
    {
137
        this->data(place).value.write(buf);
138
    }
139

140
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version  */, Arena *) const override
141
    {
142
        auto & set = this->data(place).value;
143
        set.resize(reserved);
144
        set.read(buf);
145
    }
146

147
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
148
    {
149
        ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
150
        ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
151

152
        const typename State::Set & set = this->data(place).value;
153
        auto result_vec = set.topK(threshold);
154
        size_t size = result_vec.size();
155

156
        offsets_to.push_back(offsets_to.back() + size);
157

158
        IColumn & data_to = arr_to.getData();
159

160
        if (include_counts)
161
        {
162
            auto & column_tuple = assert_cast<ColumnTuple &>(data_to);
163

164
            auto & column_key = assert_cast<ColumnVector<T> &>(column_tuple.getColumn(0)).getData();
165
            auto & column_count = assert_cast<ColumnVector<UInt64> &>(column_tuple.getColumn(1)).getData();
166
            auto & column_error = assert_cast<ColumnVector<UInt64> &>(column_tuple.getColumn(2)).getData();
167
            size_t old_size = column_key.size();
168
            column_key.resize(old_size + size);
169
            column_count.resize(old_size + size);
170
            column_error.resize(old_size + size);
171

172
            size_t i = 0;
173
            for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i)
174
            {
175
                column_key[old_size + i] = it->key;
176
                column_count[old_size + i] = it->count;
177
                column_error[old_size + i] = it->error;
178
            }
179

180
        } else
181
        {
182

183
            auto & column_key = assert_cast<ColumnVector<T> &>(data_to).getData();
184
            size_t old_size = column_key.size();
185
            column_key.resize(old_size + size);
186
            size_t i = 0;
187
            for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i)
188
            {
189
                column_key[old_size + i] = it->key;
190
            }
191
        }
192
    }
193
};
194

195

196
/// Generic implementation, it uses serialized representation as object descriptor.
197
struct AggregateFunctionTopKGenericData
198
{
199
    using Set = SpaceSaving<StringRef, StringRefHash>;
200

201
    Set value;
202
};
203

204
/** Template parameter with true value should be used for columns that store their elements in memory continuously.
205
 *  For such columns topK() can be implemented more efficiently (especially for small numeric arrays).
206
 */
207
template <bool is_plain_column, bool is_weighted>
208
class AggregateFunctionTopKGeneric
209
    : public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>
210
{
211
private:
212
    using State = AggregateFunctionTopKGenericData;
213

214
    UInt64 threshold;
215
    UInt64 reserved;
216
    bool include_counts;
217
    bool is_approx_top_k;
218

219
public:
220
    AggregateFunctionTopKGeneric(
221
        UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
222
        : IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>(argument_types_, params, createResultType(argument_types_, include_counts_))
223
        , threshold(threshold_), reserved(reserved_), include_counts(include_counts_), is_approx_top_k(is_approx_top_k_) {}
224

225
    String getName() const override
226
    {
227
        if (is_approx_top_k)
228
            return  is_weighted ? "approx_top_sum" : "approx_top_k";
229
        else
230
            return  is_weighted ? "topKWeighted" : "topK";
231
    }
232

233
    static DataTypePtr createResultType(const DataTypes & argument_types_, bool include_counts_)
234
    {
235
        if (include_counts_)
236
        {
237
            DataTypes types
238
            {
239
                argument_types_[0],
240
                std::make_shared<DataTypeUInt64>(),
241
                std::make_shared<DataTypeUInt64>(),
242
            };
243

244
            Strings names
245
            {
246
                "item",
247
                "count",
248
                "error",
249
            };
250

251
            return std::make_shared<DataTypeArray>(std::make_shared<DataTypeTuple>(
252
                std::move(types),
253
                std::move(names)
254
            ));
255

256
        } else
257
        {
258
            return std::make_shared<DataTypeArray>(argument_types_[0]);
259
        }
260
    }
261

262
    bool allocatesMemoryInArena() const override
263
    {
264
        return true;
265
    }
266

267
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
268
    {
269
        this->data(place).value.write(buf);
270
    }
271

272
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
273
    {
274
        auto & set = this->data(place).value;
275
        set.clear();
276

277
        // Specialized here because there's no deserialiser for StringRef
278
        size_t size = 0;
279
        readVarUInt(size, buf);
280
        if (unlikely(size > TOP_K_MAX_SIZE))
281
            throw Exception(
282
                ErrorCodes::ARGUMENT_OUT_OF_BOUND,
283
                "Too large size ({}) for aggregate function '{}' state (maximum is {})",
284
                size,
285
                getName(),
286
                TOP_K_MAX_SIZE);
287
        set.resize(size);
288
        for (size_t i = 0; i < size; ++i)
289
        {
290
            auto ref = readStringBinaryInto(*arena, buf);
291
            UInt64 count;
292
            UInt64 error;
293
            readVarUInt(count, buf);
294
            readVarUInt(error, buf);
295
            set.insert(ref, count, error);
296
            arena->rollback(ref.size);
297
        }
298

299
        set.readAlphaMap(buf);
300
    }
301

302
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
303
    {
304
        auto & set = this->data(place).value;
305
        if (set.capacity() != reserved)
306
            set.resize(reserved);
307

308
        if constexpr (is_plain_column)
309
        {
310
            if constexpr (is_weighted)
311
                set.insert(columns[0]->getDataAt(row_num), columns[1]->getUInt(row_num));
312
            else
313
                set.insert(columns[0]->getDataAt(row_num));
314
        }
315
        else
316
        {
317
            const char * begin = nullptr;
318
            StringRef str_serialized = columns[0]->serializeValueIntoArena(row_num, *arena, begin);
319
            if constexpr (is_weighted)
320
                set.insert(str_serialized, columns[1]->getUInt(row_num));
321
            else
322
                set.insert(str_serialized);
323
            arena->rollback(str_serialized.size);
324
        }
325
    }
326

327
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
328
    {
329
        if (!this->data(rhs).value.size())
330
            return;
331

332
        auto & set = this->data(place).value;
333
        if (set.capacity() != reserved)
334
            set.resize(reserved);
335
        set.merge(this->data(rhs).value);
336
    }
337

338
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
339
    {
340
        ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
341
        ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
342

343
        const typename State::Set & set = this->data(place).value;
344
        auto result_vec = set.topK(threshold);
345
        size_t size = result_vec.size();
346
        offsets_to.push_back(offsets_to.back() + size);
347

348
        IColumn & data_to = arr_to.getData();
349

350
        if (include_counts)
351
        {
352
            auto & column_tuple = assert_cast<ColumnTuple &>(data_to);
353
            IColumn & column_key = column_tuple.getColumn(0);
354
            IColumn & column_count = column_tuple.getColumn(1);
355
            IColumn & column_error = column_tuple.getColumn(2);
356
            for (auto &elem : result_vec)
357
            {
358
                column_count.insert(elem.count);
359
                column_error.insert(elem.error);
360
                deserializeAndInsert<is_plain_column>(elem.key, column_key);
361
            }
362
        } else
363
        {
364
            for (auto & elem : result_vec)
365
            {
366
                deserializeAndInsert<is_plain_column>(elem.key, data_to);
367
            }
368
        }
369
    }
370
};
371

372

373
/// Substitute return type for Date and DateTime
374
template <bool is_weighted>
375
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
376
{
377
public:
378
    using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK;
379

380
    AggregateFunctionTopKDate(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
381
        : AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>(
382
            threshold_,
383
            reserved_,
384
            include_counts_,
385
            is_approx_top_k_,
386
            argument_types_,
387
            params)
388
    {}
389
};
390

391
template <bool is_weighted>
392
class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>
393
{
394
public:
395
    using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK;
396

397
    AggregateFunctionTopKDateTime(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
398
        : AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>(
399
            threshold_,
400
            reserved_,
401
            include_counts_,
402
            is_approx_top_k_,
403
            argument_types_,
404
            params)
405
    {}
406
};
407

408
template <bool is_weighted>
409
class AggregateFunctionTopKIPv4 : public AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>
410
{
411
public:
412
    using AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>::AggregateFunctionTopK;
413

414
    AggregateFunctionTopKIPv4(UInt64 threshold_, UInt64 reserved_, bool include_counts_, bool is_approx_top_k_, const DataTypes & argument_types_, const Array & params)
415
        : AggregateFunctionTopK<DataTypeIPv4::FieldType, is_weighted>(
416
            threshold_,
417
            reserved_,
418
            include_counts_,
419
            is_approx_top_k_,
420
            argument_types_,
421
            params)
422
    {}
423
};
424

425

426
template <bool is_weighted>
427
IAggregateFunction * createWithExtraTypes(const DataTypes & argument_types, UInt64 threshold, UInt64 reserved, bool include_counts, bool is_approx_top_k, const Array & params)
428
{
429
    if (argument_types.empty())
430
        throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Got empty arguments list");
431

432
    WhichDataType which(argument_types[0]);
433
    if (which.idx == TypeIndex::Date)
434
        return new AggregateFunctionTopKDate<is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
435
    if (which.idx == TypeIndex::DateTime)
436
        return new AggregateFunctionTopKDateTime<is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
437
    if (which.idx == TypeIndex::IPv4)
438
        return new AggregateFunctionTopKIPv4<is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
439

440
    /// Check that we can use plain version of AggregateFunctionTopKGeneric
441
    if (argument_types[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
442
        return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
443
    else
444
        return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, reserved, include_counts, is_approx_top_k, argument_types, params);
445
}
446

447

448
template <bool is_weighted, bool is_approx_top_k>
449
AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
450
{
451
    if (!is_weighted)
452
    {
453
        assertUnary(name, argument_types);
454
    }
455
    else
456
    {
457
        assertBinary(name, argument_types);
458
        if (!isInteger(argument_types[1]))
459
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The second argument for aggregate function 'topKWeighted' must have integer type");
460
    }
461

462
    UInt64 threshold = 10;  /// default values
463
    UInt64 load_factor = 3;
464
    bool include_counts = is_approx_top_k;
465
    UInt64 reserved = threshold * load_factor;
466

467
    if (!params.empty())
468
    {
469
        if (params.size() > 3)
470
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
471
                            "Aggregate function '{}' requires three parameters or less", name);
472

473
        threshold = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
474

475
        if (params.size() >= 2)
476
        {
477
            if (is_approx_top_k)
478
            {
479
                reserved = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
480

481
                if (reserved < 1)
482
                    throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND,
483
                                    "Too small parameter 'reserved' for aggregate function '{}' (got {}, minimum is 1)", name, reserved);
484
            } else
485
            {
486
                load_factor = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
487

488
                if (load_factor < 1)
489
                    throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND,
490
                                    "Too small parameter 'load_factor' for aggregate function '{}' (got {}, minimum is 1)", name, load_factor);
491
            }
492
        }
493

494
        if (params.size() == 3)
495
        {
496
            String option = params.at(2).safeGet<String>();
497

498
            if (option == "counts")
499
                include_counts = true;
500
            else
501
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} doesn't support a parameter: {}", name, option);
502

503
        }
504

505
        if (!is_approx_top_k)
506
        {
507
            reserved = threshold * load_factor;
508
        }
509

510
        if (reserved > DB::TOP_K_MAX_SIZE || load_factor > DB::TOP_K_MAX_SIZE || threshold > DB::TOP_K_MAX_SIZE)
511
            throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND,
512
                            "Too large parameter(s) for aggregate function '{}' (maximum is {})", name, toString(TOP_K_MAX_SIZE));
513

514
        if (threshold == 0 || reserved == 0)
515
            throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Parameter 0 is illegal for aggregate function '{}'", name);
516
    }
517

518
    AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(
519
        *argument_types[0], threshold, reserved, include_counts, is_approx_top_k, argument_types, params));
520

521
    if (!res)
522
        res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types, threshold, reserved, include_counts, is_approx_top_k, params));
523

524
    if (!res)
525
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
526
                        "Illegal type {} of argument for aggregate function '{}'", argument_types[0]->getName(), name);
527
    return res;
528
}
529

530
}
531

532
/// Registers topK, topKWeighted, approx_top_k, approx_top_sum and the alias approx_top_count.
/// The "approx_top_*" names (and the alias) are registered case-insensitively.
void registerAggregateFunctionTopK(AggregateFunctionFactory & factory)
{
    /// The same properties are attached to every function registered here.
    AggregateFunctionProperties shared_properties = { .returns_default_when_only_null = false, .is_order_dependent = true };

    /// Template arguments are <is_weighted, is_approx_top_k>.
    factory.registerFunction("topK", { createAggregateFunctionTopK<false, false>, shared_properties });
    factory.registerFunction("topKWeighted", { createAggregateFunctionTopK<true, false>, shared_properties });

    factory.registerFunction(
        "approx_top_k", { createAggregateFunctionTopK<false, true>, shared_properties }, AggregateFunctionFactory::CaseInsensitive);
    factory.registerFunction(
        "approx_top_sum", { createAggregateFunctionTopK<true, true>, shared_properties }, AggregateFunctionFactory::CaseInsensitive);

    factory.registerAlias("approx_top_count", "approx_top_k", AggregateFunctionFactory::CaseInsensitive);
}
542

543
}
544

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.