ClickHouse

Форк
0
/
AggregateFunctionSumMap.cpp 
797 строк · 29.1 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <Functions/FunctionHelpers.h>
3

4
#include <IO/ReadHelpers.h>
5

6
#include <DataTypes/DataTypeArray.h>
7
#include <DataTypes/DataTypeTuple.h>
8
#include <DataTypes/DataTypeNullable.h>
9

10
#include <Columns/ColumnArray.h>
11
#include <Columns/ColumnTuple.h>
12
#include <Columns/ColumnString.h>
13

14
#include <Common/FieldVisitorSum.h>
15
#include <Common/assert_cast.h>
16
#include <AggregateFunctions/IAggregateFunction.h>
17
#include <AggregateFunctions/FactoryHelpers.h>
18
#include <map>
19

20

21
namespace DB
22
{
23

24
struct Settings;
25

26
/// Error codes thrown by the -Map aggregate functions in this file (declared here, defined elsewhere).
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
33

34
namespace
35
{
36

37
struct AggregateFunctionMapData
38
{
39
    // Map needs to be ordered to maintain function properties
40
    std::map<Field, Array> merged_maps;
41
};
42

43
/** Aggregate function, that takes at least two arguments: keys and values, and as a result, builds a tuple of at least 2 arrays -
44
  * ordered keys and variable number of argument values aggregated by corresponding keys.
45
  *
46
  * sumMap is most useful with SummingMergeTree for summing Nested columns whose names end in "Map".
47
  *
48
  * Example: sumMap(k, v...) of:
49
  *  k           v
50
  *  [1,2,3]     [10,10,10]
51
  *  [3,4,5]     [10,10,10]
52
  *  [4,5,6]     [10,10,10]
53
  *  [6,7,8]     [10,10,10]
54
  *  [7,5,3]     [5,15,25]
55
  *  [8,9,10]    [20,20,20]
56
  * will return:
57
  *  ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
58
  *
59
  * minMap and maxMap share the same idea, but calculate min and max correspondingly.
60
  *
61
  * NOTE: The implementation of these functions is "amateur grade" - not efficient and low quality.
62
  */
63

64
template <typename Derived, typename Visitor, bool overflow, bool tuple_argument, bool compact>
65
class AggregateFunctionMapBase : public IAggregateFunctionDataHelper<
66
    AggregateFunctionMapData, Derived>
67
{
68
private:
69
    static constexpr auto STATE_VERSION_1_MIN_REVISION = 54452;
70

71
    DataTypePtr keys_type;
72
    SerializationPtr keys_serialization;
73
    DataTypes values_types;
74
    Serializations values_serializations;
75
    Serializations promoted_values_serializations;
76

77
public:
78
    using Base = IAggregateFunctionDataHelper<AggregateFunctionMapData, Derived>;
79

80
    AggregateFunctionMapBase(const DataTypePtr & keys_type_,
81
            const DataTypes & values_types_, const DataTypes & argument_types_)
82
        : Base(argument_types_, {} /* parameters */, createResultType(keys_type_, values_types_))
83
        , keys_type(keys_type_)
84
        , keys_serialization(keys_type->getDefaultSerialization())
85
        , values_types(values_types_)
86
    {
87
        values_serializations.reserve(values_types.size());
88
        promoted_values_serializations.reserve(values_types.size());
89
        for (const auto & type : values_types)
90
        {
91
            values_serializations.emplace_back(type->getDefaultSerialization());
92
            if (type->canBePromoted())
93
            {
94
                if (type->isNullable())
95
                    promoted_values_serializations.emplace_back(
96
                         makeNullable(removeNullable(type)->promoteNumericType())->getDefaultSerialization());
97
                else
98
                    promoted_values_serializations.emplace_back(type->promoteNumericType()->getDefaultSerialization());
99
            }
100
            else
101
            {
102
                promoted_values_serializations.emplace_back(type->getDefaultSerialization());
103
            }
104
        }
105
    }
106

107
    bool isVersioned() const override { return true; }
108

109
    size_t getDefaultVersion() const override { return 1; }
110

111
    size_t getVersionFromRevision(size_t revision) const override
112
    {
113
        if (revision >= STATE_VERSION_1_MIN_REVISION)
114
            return 1;
115
        else
116
            return 0;
117
    }
118

119
    static DataTypePtr createResultType(
120
        const DataTypePtr & keys_type_,
121
        const DataTypes & values_types_)
122
    {
123
        DataTypes types;
124
        types.emplace_back(std::make_shared<DataTypeArray>(keys_type_));
125

126
        for (const auto & value_type : values_types_)
127
        {
128
            if constexpr (std::is_same_v<Visitor, FieldVisitorSum>)
129
            {
130
                if (!value_type->isSummable())
131
                    throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
132
                        "Values for -Map cannot be summed, passed type {}",
133
                        value_type->getName()};
134
            }
135

136
            DataTypePtr result_type;
137

138
            if constexpr (overflow)
139
            {
140
                if (value_type->onlyNull())
141
                    throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
142
                        "Cannot calculate -Map of type {}",
143
                        value_type->getName()};
144

145
                // Overflow, meaning that the returned type is the same as
146
                // the input type. Nulls are skipped.
147
                result_type = removeNullable(value_type);
148
            }
149
            else
150
            {
151
                auto value_type_without_nullable = removeNullable(value_type);
152

153
                // No overflow, meaning we promote the types if necessary.
154
                if (!value_type_without_nullable->canBePromoted())
155
                    throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
156
                        "Values for -Map are expected to be Numeric, Float or Decimal, passed type {}",
157
                        value_type->getName()};
158

159
                WhichDataType value_type_to_check(value_type_without_nullable);
160

161
                /// Do not promote decimal because of implementation issues of this function design
162
                /// Currently we cannot get result column type in case of decimal we cannot get decimal scale
163
                /// in method void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
164
                /// If we decide to make this function more efficient we should promote decimal type during summ
165
                if (value_type_to_check.isDecimal())
166
                    result_type = value_type_without_nullable;
167
                else
168
                    result_type = value_type_without_nullable->promoteNumericType();
169
            }
170

171
            types.emplace_back(std::make_shared<DataTypeArray>(result_type));
172
        }
173

174
        return std::make_shared<DataTypeTuple>(types);
175
    }
176

177
    bool allocatesMemoryInArena() const override { return false; }
178

179
    static auto getArgumentColumns(const IColumn ** columns)
180
    {
181
        if constexpr (tuple_argument)
182
        {
183
            return assert_cast<const ColumnTuple *>(columns[0])->getColumns();
184
        }
185
        else
186
        {
187
            return columns;
188
        }
189
    }
190

191
    void add(AggregateDataPtr __restrict place, const IColumn ** columns_, const size_t row_num, Arena *) const override
192
    {
193
        const auto & columns = getArgumentColumns(columns_);
194

195
        // Column 0 contains array of keys of known type
196
        const ColumnArray & array_column0 = assert_cast<const ColumnArray &>(*columns[0]);
197
        const IColumn::Offsets & offsets0 = array_column0.getOffsets();
198
        const IColumn & key_column = array_column0.getData();
199
        const size_t keys_vec_offset = offsets0[row_num - 1];
200
        const size_t keys_vec_size = (offsets0[row_num] - keys_vec_offset);
201

202
        // Columns 1..n contain arrays of numeric values to sum
203
        auto & merged_maps = this->data(place).merged_maps;
204
        for (size_t col = 0, size = values_types.size(); col < size; ++col)
205
        {
206
            const auto & array_column = assert_cast<const ColumnArray &>(*columns[col + 1]);
207
            const IColumn & value_column = array_column.getData();
208
            const IColumn::Offsets & offsets = array_column.getOffsets();
209
            const size_t values_vec_offset = offsets[row_num - 1];
210
            const size_t values_vec_size = (offsets[row_num] - values_vec_offset);
211

212
            // Expect key and value arrays to be of same length
213
            if (keys_vec_size != values_vec_size)
214
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Sizes of keys and values arrays do not match");
215

216
            // Insert column values for all keys
217
            for (size_t i = 0; i < keys_vec_size; ++i)
218
            {
219
                Field value = value_column[values_vec_offset + i];
220
                Field key = key_column[keys_vec_offset + i];
221

222
                if (!keepKey(key))
223
                    continue;
224

225
                auto [it, inserted] = merged_maps.emplace(key, Array());
226

227
                if (inserted)
228
                {
229
                    it->second.resize(size);
230
                    it->second[col] = value;
231
                }
232
                else
233
                {
234
                    if (!value.isNull())
235
                    {
236
                        if (it->second[col].isNull())
237
                            it->second[col] = value;
238
                        else
239
                            applyVisitor(Visitor(value), it->second[col]);
240
                    }
241
                }
242
            }
243
        }
244
    }
245

246
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
247
    {
248
        auto & merged_maps = this->data(place).merged_maps;
249
        const auto & rhs_maps = this->data(rhs).merged_maps;
250

251
        for (const auto & elem : rhs_maps)
252
        {
253
            const auto & it = merged_maps.find(elem.first);
254
            if (it != merged_maps.end())
255
            {
256
                for (size_t col = 0; col < values_types.size(); ++col)
257
                {
258
                    if (!elem.second[col].isNull())
259
                    {
260
                        if (it->second[col].isNull())
261
                            it->second[col] = elem.second[col];
262
                        else
263
                            applyVisitor(Visitor(elem.second[col]), it->second[col]);
264
                    }
265
                }
266
            }
267
            else
268
            {
269
                merged_maps[elem.first] = elem.second;
270
            }
271
        }
272
    }
273

274
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
275
    {
276
        if (!version)
277
            version = getDefaultVersion();
278

279
        const auto & merged_maps = this->data(place).merged_maps;
280
        size_t size = merged_maps.size();
281
        writeVarUInt(size, buf);
282

283
        std::function<void(size_t, const Array &)> serialize;
284
        switch (*version)
285
        {
286
            case 0:
287
            {
288
                serialize = [&](size_t col_idx, const Array & values)
289
                {
290
                    values_serializations[col_idx]->serializeBinary(values[col_idx], buf, {});
291
                };
292
                break;
293
            }
294
            case 1:
295
            {
296
                serialize = [&](size_t col_idx, const Array & values)
297
                {
298
                    Field value = values[col_idx];
299

300
                    /// Compatibility with previous versions.
301
                    if (value.getType() == Field::Types::Decimal32)
302
                    {
303
                        auto source = value.get<DecimalField<Decimal32>>();
304
                        value = DecimalField<Decimal128>(source.getValue(), source.getScale());
305
                    }
306
                    else if (value.getType() == Field::Types::Decimal64)
307
                    {
308
                        auto source = value.get<DecimalField<Decimal64>>();
309
                        value = DecimalField<Decimal128>(source.getValue(), source.getScale());
310
                    }
311

312
                    promoted_values_serializations[col_idx]->serializeBinary(value, buf, {});
313
                };
314
                break;
315
            }
316
            default:
317
                throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown version {}, of -Map aggregate function serialization state", *version);
318
        }
319

320
        for (const auto & elem : merged_maps)
321
        {
322
            keys_serialization->serializeBinary(elem.first, buf, {});
323
            for (size_t col = 0; col < values_types.size(); ++col)
324
                serialize(col, elem.second);
325
        }
326
    }
327

328
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena *) const override
329
    {
330
        if (!version)
331
            version = getDefaultVersion();
332

333
        auto & merged_maps = this->data(place).merged_maps;
334
        size_t size = 0;
335
        readVarUInt(size, buf);
336

337
        std::function<void(size_t, Array &)> deserialize;
338
        switch (*version)
339
        {
340
            case 0:
341
            {
342
                deserialize = [&](size_t col_idx, Array & values)
343
                {
344
                    values_serializations[col_idx]->deserializeBinary(values[col_idx], buf, {});
345
                };
346
                break;
347
            }
348
            case 1:
349
            {
350
                deserialize = [&](size_t col_idx, Array & values)
351
                {
352
                    Field & value = values[col_idx];
353
                    promoted_values_serializations[col_idx]->deserializeBinary(value, buf, {});
354

355
                    /// Compatibility with previous versions.
356
                    if (value.getType() == Field::Types::Decimal128)
357
                    {
358
                        auto source = value.get<DecimalField<Decimal128>>();
359
                        WhichDataType value_type(values_types[col_idx]);
360
                        if (value_type.isDecimal32())
361
                        {
362
                            value = DecimalField<Decimal32>(source.getValue(), source.getScale());
363
                        }
364
                        else if (value_type.isDecimal64())
365
                        {
366
                            value = DecimalField<Decimal64>(source.getValue(), source.getScale());
367
                        }
368
                    }
369
                };
370
                break;
371
            }
372
            default:
373
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unexpected version {} of -Map aggregate function serialization state", *version);
374
        }
375

376
        for (size_t i = 0; i < size; ++i)
377
        {
378
            Field key;
379
            keys_serialization->deserializeBinary(key, buf, {});
380

381
            Array values;
382
            values.resize(values_types.size());
383

384
            for (size_t col = 0; col < values_types.size(); ++col)
385
                deserialize(col, values);
386

387
            merged_maps[key] = values;
388
        }
389
    }
390

391
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
392
    {
393
        size_t num_columns = values_types.size();
394

395
        // Final step does compaction of keys that have zero values, this mutates the state
396
        auto & merged_maps = this->data(place).merged_maps;
397

398
        // Remove keys which are zeros or empty. This should be enabled only for sumMap.
399
        if constexpr (compact)
400
        {
401
            for (auto it = merged_maps.cbegin(); it != merged_maps.cend();)
402
            {
403
                // Key is not compacted if it has at least one non-zero value
404
                bool erase = true;
405
                for (size_t col = 0; col < num_columns; ++col)
406
                {
407
                    if (!it->second[col].isNull() && it->second[col] != values_types[col]->getDefault())
408
                    {
409
                        erase = false;
410
                        break;
411
                    }
412
                }
413

414
                if (erase)
415
                    it = merged_maps.erase(it);
416
                else
417
                    ++it;
418
            }
419
        }
420

421
        size_t size = merged_maps.size();
422

423
        auto & to_tuple = assert_cast<ColumnTuple &>(to);
424
        auto & to_keys_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(0));
425
        auto & to_keys_col = to_keys_arr.getData();
426

427
        // Advance column offsets
428
        auto & to_keys_offsets = to_keys_arr.getOffsets();
429
        to_keys_offsets.push_back(to_keys_offsets.back() + size);
430
        to_keys_col.reserve(size);
431

432
        for (size_t col = 0; col < num_columns; ++col)
433
        {
434
            auto & to_values_arr = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1));
435
            auto & to_values_offsets = to_values_arr.getOffsets();
436
            to_values_offsets.push_back(to_values_offsets.back() + size);
437
            to_values_arr.getData().reserve(size);
438
        }
439

440
        // Write arrays of keys and values
441
        for (const auto & elem : merged_maps)
442
        {
443
            // Write array of keys into column
444
            to_keys_col.insert(elem.first);
445

446
            // Write 0..n arrays of values
447
            for (size_t col = 0; col < num_columns; ++col)
448
            {
449
                auto & to_values_col = assert_cast<ColumnArray &>(to_tuple.getColumn(col + 1)).getData();
450
                if (elem.second[col].isNull())
451
                    to_values_col.insertDefault();
452
                else
453
                    to_values_col.insert(elem.second[col]);
454
            }
455
        }
456
    }
457

458
    bool keepKey(const Field & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
459
    String getName() const override { return Derived::getNameImpl(); }
460
};
461

462
template <bool overflow, bool tuple_argument>
463
class AggregateFunctionSumMap final :
464
    public AggregateFunctionMapBase<AggregateFunctionSumMap<overflow, tuple_argument>, FieldVisitorSum, overflow, tuple_argument, true>
465
{
466
private:
467
    using Self = AggregateFunctionSumMap<overflow, tuple_argument>;
468
    using Base = AggregateFunctionMapBase<Self, FieldVisitorSum, overflow, tuple_argument, true>;
469

470
public:
471
    AggregateFunctionSumMap(const DataTypePtr & keys_type_,
472
            DataTypes & values_types_, const DataTypes & argument_types_,
473
            const Array & params_)
474
        : Base{keys_type_, values_types_, argument_types_}
475
    {
476
        // The constructor accepts parameters to have a uniform interface with
477
        // sumMapFiltered, but this function doesn't have any parameters.
478
        assertNoParameters(getNameImpl(), params_);
479
    }
480

481
    static String getNameImpl()
482
    {
483
        if constexpr (overflow)
484
        {
485
            return "sumMapWithOverflow";
486
        }
487
        else
488
        {
489
            return "sumMap";
490
        }
491
    }
492

493
    bool keepKey(const Field &) const { return true; }
494
};
495

496

497
template <bool overflow, bool tuple_argument>
498
class AggregateFunctionSumMapFiltered final :
499
    public AggregateFunctionMapBase<
500
        AggregateFunctionSumMapFiltered<overflow, tuple_argument>,
501
        FieldVisitorSum,
502
        overflow,
503
        tuple_argument,
504
        true>
505
{
506
private:
507
    using Self = AggregateFunctionSumMapFiltered<overflow, tuple_argument>;
508
    using Base = AggregateFunctionMapBase<Self, FieldVisitorSum, overflow, tuple_argument, true>;
509

510
    using ContainerT = std::set<Field>;
511
    ContainerT keys_to_keep;
512

513
public:
514
    AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type_,
515
            const DataTypes & values_types_, const DataTypes & argument_types_,
516
            const Array & params_)
517
        : Base{keys_type_, values_types_, argument_types_}
518
    {
519
        if (params_.size() != 1)
520
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
521
                "Aggregate function '{}' requires exactly one parameter "
522
                "of Array type", getNameImpl());
523

524
        Array keys_to_keep_values;
525
        if (!params_.front().tryGet<Array>(keys_to_keep_values))
526
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
527
                "Aggregate function {} requires an Array as a parameter",
528
                getNameImpl());
529

530
        this->parameters = params_;
531

532
        for (const Field & f : keys_to_keep_values)
533
            keys_to_keep.emplace(f);
534
    }
535

536
    static String getNameImpl()
537
    {
538
        if constexpr (overflow)
539
        {
540
            return "sumMapFilteredWithOverflow";
541
        }
542
        else
543
        {
544
            return "sumMapFiltered";
545
        }
546
    }
547

548
    bool keepKey(const Field & key) const { return keys_to_keep.contains(key); }
549
};
550

551

552
/** Implements `Max` operation.
553
 *  Returns true if changed
554
 */
555
class FieldVisitorMax : public StaticVisitor<bool>
556
{
557
private:
558
    const Field & rhs;
559

560
    template <typename FieldType>
561
    bool compareImpl(FieldType & x) const
562
    {
563
        auto val = rhs.get<FieldType>();
564
        if (val > x)
565
        {
566
            x = val;
567
            return true;
568
        }
569

570
        return false;
571
    }
572

573
public:
574
    explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {}
575

576
    bool operator() (Null &) const
577
    {
578
        /// Do not update current value, skip nulls
579
        return false;
580
    }
581

582
    bool operator() (AggregateFunctionStateData &) const { throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot compare AggregateFunctionStates"); }
583

584
    bool operator() (Array & x) const { return compareImpl<Array>(x); }
585
    bool operator() (Tuple & x) const { return compareImpl<Tuple>(x); }
586
    template <typename T>
587
    bool operator() (DecimalField<T> & x) const { return compareImpl<DecimalField<T>>(x); }
588
    template <typename T>
589
    bool operator() (T & x) const { return compareImpl<T>(x); }
590
};
591

592
/** Implements `Min` operation.
593
 *  Returns true if changed
594
 */
595
class FieldVisitorMin : public StaticVisitor<bool>
596
{
597
private:
598
    const Field & rhs;
599

600
    template <typename FieldType>
601
    bool compareImpl(FieldType & x) const
602
    {
603
        auto val = rhs.get<FieldType>();
604
        if (val < x)
605
        {
606
            x = val;
607
            return true;
608
        }
609

610
        return false;
611
    }
612

613
public:
614
    explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {}
615

616

617
    bool operator() (Null &) const
618
    {
619
        /// Do not update current value, skip nulls
620
        return false;
621
    }
622

623
    bool operator() (AggregateFunctionStateData &) const { throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot sum AggregateFunctionStates"); }
624

625
    bool operator() (Array & x) const { return compareImpl<Array>(x); }
626
    bool operator() (Tuple & x) const { return compareImpl<Tuple>(x); }
627
    template <typename T>
628
    bool operator() (DecimalField<T> & x) const { return compareImpl<DecimalField<T>>(x); }
629
    template <typename T>
630
    bool operator() (T & x) const { return compareImpl<T>(x); }
631
};
632

633

634
template <bool tuple_argument>
635
class AggregateFunctionMinMap final :
636
    public AggregateFunctionMapBase<AggregateFunctionMinMap<tuple_argument>, FieldVisitorMin, true, tuple_argument, false>
637
{
638
private:
639
    using Self = AggregateFunctionMinMap<tuple_argument>;
640
    using Base = AggregateFunctionMapBase<Self, FieldVisitorMin, true, tuple_argument, false>;
641

642
public:
643
    AggregateFunctionMinMap(const DataTypePtr & keys_type_,
644
            DataTypes & values_types_, const DataTypes & argument_types_,
645
            const Array & params_)
646
        : Base{keys_type_, values_types_, argument_types_}
647
    {
648
        // The constructor accepts parameters to have a uniform interface with
649
        // sumMapFiltered, but this function doesn't have any parameters.
650
        assertNoParameters(getNameImpl(), params_);
651
    }
652

653
    static String getNameImpl() { return "minMap"; }
654

655
    bool keepKey(const Field &) const { return true; }
656
};
657

658
template <bool tuple_argument>
659
class AggregateFunctionMaxMap final :
660
    public AggregateFunctionMapBase<AggregateFunctionMaxMap<tuple_argument>, FieldVisitorMax, true, tuple_argument, false>
661
{
662
private:
663
    using Self = AggregateFunctionMaxMap<tuple_argument>;
664
    using Base = AggregateFunctionMapBase<Self, FieldVisitorMax, true, tuple_argument, false>;
665

666
public:
667
    AggregateFunctionMaxMap(const DataTypePtr & keys_type_,
668
            DataTypes & values_types_, const DataTypes & argument_types_,
669
            const Array & params_)
670
        : Base{keys_type_, values_types_, argument_types_}
671
    {
672
        // The constructor accepts parameters to have a uniform interface with
673
        // sumMapFiltered, but this function doesn't have any parameters.
674
        assertNoParameters(getNameImpl(), params_);
675
    }
676

677
    static String getNameImpl() { return "maxMap"; }
678

679
    bool keepKey(const Field &) const { return true; }
680
};
681

682

683
auto parseArguments(const std::string & name, const DataTypes & arguments)
684
{
685
    DataTypes args;
686
    bool tuple_argument = false;
687

688
    if (arguments.size() == 1)
689
    {
690
        // sumMap state is fully given by its result, so it can be stored in
691
        // SimpleAggregateFunction columns. There is a caveat: it must support
692
        // sumMap(sumMap(...)), e.g. it must be able to accept its own output as
693
        // an input. This is why it also accepts a Tuple(keys, values) argument.
694
        const auto * tuple_type = checkAndGetDataType<DataTypeTuple>(arguments[0].get());
695
        if (!tuple_type)
696
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "When function {} gets one argument it must be a tuple", name);
697

698
        const auto elems = tuple_type->getElements();
699
        args.insert(args.end(), elems.begin(), elems.end());
700
        tuple_argument = true;
701
    }
702
    else
703
    {
704
        args.insert(args.end(), arguments.begin(), arguments.end());
705
        tuple_argument = false;
706
    }
707

708
    if (args.size() < 2)
709
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
710
            "Aggregate function {} requires at least two arguments of Array type or one argument of tuple of two arrays", name);
711

712
    const auto * array_type = checkAndGetDataType<DataTypeArray>(args[0].get());
713
    if (!array_type)
714
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be an array, not {}",
715
            name, args[0]->getName());
716

717
    DataTypePtr keys_type = array_type->getNestedType();
718

719
    DataTypes values_types;
720
    values_types.reserve(args.size() - 1);
721
    for (size_t i = 1; i < args.size(); ++i)
722
    {
723
        array_type = checkAndGetDataType<DataTypeArray>(args[i].get());
724
        if (!array_type)
725
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument #{} for function {} must be an array.",
726
                i, name);
727
        values_types.push_back(array_type->getNestedType());
728
    }
729

730
    return std::tuple<DataTypePtr, DataTypes, bool>{std::move(keys_type), std::move(values_types), tuple_argument};
731
}
732

733
}
734

735
/// Registers the array-based -Map aggregate functions. Each creator parses the arguments and
/// picks the tuple-argument or plain-arrays instantiation of the implementation class.
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
{
    // The plain *Map names now belong to the Map combinator, which forwards here
    // when it is called with array or tuple arguments; hence the *MappedArrays names.
    factory.registerFunction("sumMappedArrays", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
    {
        auto [key_type, value_types, is_tuple] = parseArguments(name, arguments);
        if (!is_tuple)
            return std::make_shared<AggregateFunctionSumMap<false, false>>(key_type, value_types, arguments, params);
        return std::make_shared<AggregateFunctionSumMap<false, true>>(key_type, value_types, arguments, params);
    });

    factory.registerFunction("minMappedArrays", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
    {
        auto [key_type, value_types, is_tuple] = parseArguments(name, arguments);
        if (!is_tuple)
            return std::make_shared<AggregateFunctionMinMap<false>>(key_type, value_types, arguments, params);
        return std::make_shared<AggregateFunctionMinMap<true>>(key_type, value_types, arguments, params);
    });

    factory.registerFunction("maxMappedArrays", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
    {
        auto [key_type, value_types, is_tuple] = parseArguments(name, arguments);
        if (!is_tuple)
            return std::make_shared<AggregateFunctionMaxMap<false>>(key_type, value_types, arguments, params);
        return std::make_shared<AggregateFunctionMaxMap<true>>(key_type, value_types, arguments, params);
    });

    // Renaming the functions below to *MappedArrays would break backward compatibility,
    // so they keep their historical names.
    factory.registerFunction("sumMapWithOverflow", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
    {
        auto [key_type, value_types, is_tuple] = parseArguments(name, arguments);
        if (!is_tuple)
            return std::make_shared<AggregateFunctionSumMap<true, false>>(key_type, value_types, arguments, params);
        return std::make_shared<AggregateFunctionSumMap<true, true>>(key_type, value_types, arguments, params);
    });

    factory.registerFunction("sumMapFiltered", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
    {
        auto [key_type, value_types, is_tuple] = parseArguments(name, arguments);
        if (!is_tuple)
            return std::make_shared<AggregateFunctionSumMapFiltered<false, false>>(key_type, value_types, arguments, params);
        return std::make_shared<AggregateFunctionSumMapFiltered<false, true>>(key_type, value_types, arguments, params);
    });

    factory.registerFunction("sumMapFilteredWithOverflow", [](const std::string & name, const DataTypes & arguments, const Array & params, const Settings *) -> AggregateFunctionPtr
    {
        auto [key_type, value_types, is_tuple] = parseArguments(name, arguments);
        if (!is_tuple)
            return std::make_shared<AggregateFunctionSumMapFiltered<true, false>>(key_type, value_types, arguments, params);
        return std::make_shared<AggregateFunctionSumMapFiltered<true, true>>(key_type, value_types, arguments, params);
    });
}
796

797
}
798

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.