ClickHouse

Форк
0
/
AggregateFunctionMap.cpp 
467 строк · 17.2 Кб
1
#include <unordered_map>
2
#include <AggregateFunctions/AggregateFunctionFactory.h>
3
#include <AggregateFunctions/IAggregateFunction.h>
4
#include <Columns/ColumnFixedString.h>
5
#include <Columns/ColumnMap.h>
6
#include <Columns/ColumnString.h>
7
#include <Columns/ColumnTuple.h>
8
#include <Columns/ColumnVector.h>
9
#include <DataTypes/DataTypeArray.h>
10
#include <DataTypes/DataTypeMap.h>
11
#include <DataTypes/DataTypeTuple.h>
12
#include <DataTypes/DataTypesNumber.h>
13
#include <Functions/FunctionHelpers.h>
14
#include <IO/ReadHelpers.h>
15
#include <IO/WriteHelpers.h>
16
#include <Common/Arena.h>
17
#include "AggregateFunctionCombinatorFactory.h"
18

19

20
namespace DB
21
{
22

23
namespace ErrorCodes
24
{
25
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
26
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
27
}
28

29
namespace
30
{
31

32
template <typename KeyType>
33
struct AggregateFunctionMapCombinatorData
34
{
35
    using SearchType = KeyType;
36
    std::unordered_map<KeyType, AggregateDataPtr> merged_maps;
37

38
    static void writeKey(KeyType key, WriteBuffer & buf) { writeBinaryLittleEndian(key, buf); }
39
    static void readKey(KeyType & key, ReadBuffer & buf) { readBinaryLittleEndian(key, buf); }
40
};
41

42
template <>
43
struct AggregateFunctionMapCombinatorData<String>
44
{
45
    struct StringHash
46
    {
47
        using hash_type = std::hash<std::string_view>;
48
        using is_transparent = void;
49

50
        size_t operator()(std::string_view str) const { return hash_type{}(str); }
51
    };
52

53
    using SearchType = std::string_view;
54
    std::unordered_map<String, AggregateDataPtr, StringHash, std::equal_to<>> merged_maps;
55

56
    static void writeKey(String key, WriteBuffer & buf)
57
    {
58
        writeStringBinary(key, buf);
59
    }
60
    static void readKey(String & key, ReadBuffer & buf)
61
    {
62
        readStringBinary(key, buf);
63
    }
64
};
65

66
/// Specialization for IPv6 - for historical reasons it should be stored as FixedString(16)
67
template <>
68
struct AggregateFunctionMapCombinatorData<IPv6>
69
{
70
    struct IPv6Hash
71
    {
72
        using hash_type = std::hash<IPv6>;
73
        using is_transparent = void;
74

75
        size_t operator()(const IPv6 & ip) const { return hash_type{}(ip); }
76
    };
77

78
    using SearchType = IPv6;
79
    std::unordered_map<IPv6, AggregateDataPtr, IPv6Hash, std::equal_to<>> merged_maps;
80

81
    static void writeKey(const IPv6 & key, WriteBuffer & buf)
82
    {
83
        writeIPv6Binary(key, buf);
84
    }
85
    static void readKey(IPv6 & key, ReadBuffer & buf)
86
    {
87
        readIPv6Binary(key, buf);
88
    }
89
};
90

91
template <typename KeyType>
92
class AggregateFunctionMap final
93
    : public IAggregateFunctionDataHelper<AggregateFunctionMapCombinatorData<KeyType>, AggregateFunctionMap<KeyType>>
94
{
95
private:
96
    DataTypePtr key_type;
97
    AggregateFunctionPtr nested_func;
98

99
    using Data = AggregateFunctionMapCombinatorData<KeyType>;
100
    using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionMap<KeyType>>;
101

102
public:
103
    bool isState() const override
104
    {
105
        return nested_func->isState();
106
    }
107

108
    bool isVersioned() const override
109
    {
110
        return nested_func->isVersioned();
111
    }
112

113
    size_t getVersionFromRevision(size_t revision) const override
114
    {
115
        return nested_func->getVersionFromRevision(revision);
116
    }
117

118
    size_t getDefaultVersion() const override
119
    {
120
        return nested_func->getDefaultVersion();
121
    }
122

123
    AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types)
124
        : Base(types, nested->getParameters(), std::make_shared<DataTypeMap>(DataTypes{getKeyType(types, nested), nested->getResultType()}))
125
        , nested_func(nested)
126
    {
127
        key_type = getKeyType(types, nested_func);
128
    }
129

130
    String getName() const override { return nested_func->getName() + "Map"; }
131

132
    static DataTypePtr getKeyType(const DataTypes & types, const AggregateFunctionPtr & nested)
133
    {
134
        if (types.size() != 1)
135
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
136
            "Aggregate function {}Map requires one map argument, but {} found", nested->getName(), types.size());
137

138
        const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());
139
        if (!map_type)
140
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
141
            "Aggregate function {}Map requires map as argument", nested->getName());
142

143
        return map_type->getKeyType();
144
    }
145

146
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
147
    {
148
        const auto & map_column = assert_cast<const ColumnMap &>(*columns[0]);
149
        const auto & map_nested_tuple = map_column.getNestedData();
150
        const IColumn::Offsets & map_array_offsets = map_column.getNestedColumn().getOffsets();
151

152
        const size_t offset = map_array_offsets[row_num - 1];
153
        const size_t size = (map_array_offsets[row_num] - offset);
154

155
        const auto & key_column = map_nested_tuple.getColumn(0);
156
        const auto & val_column = map_nested_tuple.getColumn(1);
157

158
        auto & merged_maps = this->data(place).merged_maps;
159

160
        for (size_t i = 0; i < size; ++i)
161
        {
162
            typename Data::SearchType key;
163

164
            if constexpr (std::is_same_v<KeyType, String>)
165
            {
166
                StringRef key_ref;
167
                if (key_type->getTypeId() == TypeIndex::FixedString)
168
                    key_ref = assert_cast<const ColumnFixedString &>(key_column).getDataAt(offset + i);
169
                else if (key_type->getTypeId() == TypeIndex::IPv6)
170
                    key_ref = assert_cast<const ColumnIPv6 &>(key_column).getDataAt(offset + i);
171
                else
172
                    key_ref = assert_cast<const ColumnString &>(key_column).getDataAt(offset + i);
173

174
                key = key_ref.toView();
175
            }
176
            else
177
            {
178
                key = assert_cast<const ColumnVector<KeyType> &>(key_column).getData()[offset + i];
179
            }
180

181
            AggregateDataPtr nested_place;
182
            auto it = merged_maps.find(key);
183

184
            if (it == merged_maps.end())
185
            {
186
                // create a new place for each key
187
                nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData());
188
                nested_func->create(nested_place);
189
                merged_maps.emplace(key, nested_place);
190
            }
191
            else
192
                nested_place = it->second;
193

194
            const IColumn * nested_columns[1] = {&val_column};
195
            nested_func->add(nested_place, nested_columns, offset + i, arena);
196
        }
197
    }
198

199
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
200
    {
201
        auto & merged_maps = this->data(place).merged_maps;
202
        const auto & rhs_maps = this->data(rhs).merged_maps;
203

204
        for (const auto & elem : rhs_maps)
205
        {
206
            const auto & it = merged_maps.find(elem.first);
207

208
            AggregateDataPtr nested_place;
209
            if (it == merged_maps.end())
210
            {
211
                // elem.second cannot be copied since this it will be destroyed after merging,
212
                // and lead to use-after-free.
213
                nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData());
214
                nested_func->create(nested_place);
215
                merged_maps.emplace(elem.first, nested_place);
216
            }
217
            else
218
            {
219
                nested_place = it->second;
220
            }
221

222
            nested_func->merge(nested_place, elem.second, arena);
223
        }
224
    }
225

226
    template <bool up_to_state>
227
    void destroyImpl(AggregateDataPtr __restrict place) const noexcept
228
    {
229
        AggregateFunctionMapCombinatorData<KeyType> & state = Base::data(place);
230

231
        for (const auto & [key, nested_place] : state.merged_maps)
232
        {
233
            if constexpr (up_to_state)
234
                nested_func->destroyUpToState(nested_place);
235
            else
236
                nested_func->destroy(nested_place);
237
        }
238

239
        state.~Data();
240
    }
241

242
    void destroy(AggregateDataPtr __restrict place) const noexcept override
243
    {
244
        destroyImpl<false>(place);
245
    }
246

247
    bool hasTrivialDestructor() const override
248
    {
249
        return std::is_trivially_destructible_v<Data> && nested_func->hasTrivialDestructor();
250
    }
251

252
    void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override
253
    {
254
        destroyImpl<true>(place);
255
    }
256

257
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
258
    {
259
        auto & merged_maps = this->data(place).merged_maps;
260
        writeVarUInt(merged_maps.size(), buf);
261

262
        for (const auto & elem : merged_maps)
263
        {
264
            this->data(place).writeKey(elem.first, buf);
265
            nested_func->serialize(elem.second, buf);
266
        }
267
    }
268

269
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
270
    {
271
        auto & merged_maps = this->data(place).merged_maps;
272
        UInt64 size;
273

274
        readVarUInt(size, buf);
275
        for (UInt64 i = 0; i < size; ++i)
276
        {
277
            KeyType key;
278
            AggregateDataPtr nested_place;
279

280
            this->data(place).readKey(key, buf);
281
            nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData());
282
            nested_func->create(nested_place);
283
            merged_maps.emplace(key, nested_place);
284
            nested_func->deserialize(nested_place, buf, std::nullopt, arena);
285
        }
286
    }
287

288
    template <bool merge>
289
    void insertResultIntoImpl(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const
290
    {
291
        auto & map_column = assert_cast<ColumnMap &>(to);
292
        auto & nested_column = map_column.getNestedColumn();
293
        auto & nested_data_column = map_column.getNestedData();
294

295
        auto & key_column = nested_data_column.getColumn(0);
296
        auto & val_column = nested_data_column.getColumn(1);
297

298
        auto & merged_maps = this->data(place).merged_maps;
299

300
        // sort the keys
301
        std::vector<KeyType> keys;
302
        keys.reserve(merged_maps.size());
303
        for (auto & it : merged_maps)
304
        {
305
            keys.push_back(it.first);
306
        }
307
        ::sort(keys.begin(), keys.end());
308

309
        // insert using sorted keys to result column
310
        for (auto & key : keys)
311
        {
312
            key_column.insert(key);
313
            if constexpr (merge)
314
                nested_func->insertMergeResultInto(merged_maps[key], val_column, arena);
315
            else
316
                nested_func->insertResultInto(merged_maps[key], val_column, arena);
317
        }
318

319
        IColumn::Offsets & res_offsets = nested_column.getOffsets();
320
        res_offsets.push_back(val_column.size());
321
    }
322

323
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
324
    {
325
        insertResultIntoImpl<false>(place, to, arena);
326
    }
327

328
    void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
329
    {
330
        insertResultIntoImpl<true>(place, to, arena);
331
    }
332

333
    bool allocatesMemoryInArena() const override { return true; }
334

335
    AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
336
};
337

338

339
class AggregateFunctionCombinatorMap final : public IAggregateFunctionCombinator
340
{
341
public:
342
    String getName() const override { return "Map"; }
343

344
    DataTypes transformArguments(const DataTypes & arguments) const override
345
    {
346
        if (arguments.empty())
347
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
348
                "Incorrect number of arguments for aggregate function with {} suffix", getName());
349

350
        const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());
351
        if (map_type)
352
        {
353
            if (arguments.size() > 1)
354
                throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "{} combinator takes only one map argument", getName());
355

356
            return DataTypes({map_type->getValueType()});
357
        }
358

359
        // we need this part just to pass to redirection for mapped arrays
360
        auto check_func = [](DataTypePtr t) { return t->getTypeId() == TypeIndex::Array; };
361

362
        const auto * tup_type = checkAndGetDataType<DataTypeTuple>(arguments[0].get());
363
        if (tup_type)
364
        {
365
            const auto & types = tup_type->getElements();
366
            bool arrays_match = arguments.size() == 1 && types.size() >= 2 && std::all_of(types.begin(), types.end(), check_func);
367
            if (arrays_match)
368
            {
369
                const auto * val_array_type = assert_cast<const DataTypeArray *>(types[1].get());
370
                return DataTypes({val_array_type->getNestedType()});
371
            }
372
        }
373
        else
374
        {
375
            bool arrays_match = arguments.size() >= 2 && std::all_of(arguments.begin(), arguments.end(), check_func);
376
            if (arrays_match)
377
            {
378
                const auto * val_array_type = assert_cast<const DataTypeArray *>(arguments[1].get());
379
                return DataTypes({val_array_type->getNestedType()});
380
            }
381
        }
382

383
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} requires map as argument", getName());
384
    }
385

386
    AggregateFunctionPtr transformAggregateFunction(
387
        const AggregateFunctionPtr & nested_function,
388
        const AggregateFunctionProperties &,
389
        const DataTypes & arguments,
390
        const Array & params) const override
391
    {
392
        const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());
393
        if (map_type)
394
        {
395
            const auto & key_type = map_type->getKeyType();
396

397
            switch (key_type->getTypeId())
398
            {
399
                case TypeIndex::Enum8:
400
                case TypeIndex::Int8:
401
                    return std::make_shared<AggregateFunctionMap<Int8>>(nested_function, arguments);
402
                case TypeIndex::Enum16:
403
                case TypeIndex::Int16:
404
                    return std::make_shared<AggregateFunctionMap<Int16>>(nested_function, arguments);
405
                case TypeIndex::Int32:
406
                    return std::make_shared<AggregateFunctionMap<Int32>>(nested_function, arguments);
407
                case TypeIndex::Int64:
408
                    return std::make_shared<AggregateFunctionMap<Int64>>(nested_function, arguments);
409
                case TypeIndex::Int128:
410
                    return std::make_shared<AggregateFunctionMap<Int128>>(nested_function, arguments);
411
                case TypeIndex::Int256:
412
                    return std::make_shared<AggregateFunctionMap<Int256>>(nested_function, arguments);
413
                case TypeIndex::UInt8:
414
                    return std::make_shared<AggregateFunctionMap<UInt8>>(nested_function, arguments);
415
                case TypeIndex::Date:
416
                case TypeIndex::UInt16:
417
                    return std::make_shared<AggregateFunctionMap<UInt16>>(nested_function, arguments);
418
                case TypeIndex::DateTime:
419
                case TypeIndex::UInt32:
420
                    return std::make_shared<AggregateFunctionMap<UInt32>>(nested_function, arguments);
421
                case TypeIndex::UInt64:
422
                    return std::make_shared<AggregateFunctionMap<UInt64>>(nested_function, arguments);
423
                case TypeIndex::UInt128:
424
                    return std::make_shared<AggregateFunctionMap<UInt128>>(nested_function, arguments);
425
                case TypeIndex::UInt256:
426
                    return std::make_shared<AggregateFunctionMap<UInt256>>(nested_function, arguments);
427
                case TypeIndex::UUID:
428
                    return std::make_shared<AggregateFunctionMap<UUID>>(nested_function, arguments);
429
                case TypeIndex::IPv4:
430
                    return std::make_shared<AggregateFunctionMap<IPv4>>(nested_function, arguments);
431
                case TypeIndex::IPv6:
432
                    return std::make_shared<AggregateFunctionMap<IPv6>>(nested_function, arguments);
433
                case TypeIndex::FixedString:
434
                case TypeIndex::String:
435
                    return std::make_shared<AggregateFunctionMap<String>>(nested_function, arguments);
436
                default:
437
                    throw Exception(
438
                        ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
439
                        "Map key type {} is not is not supported by combinator {}", key_type->getName(), getName());
440
            }
441
        }
442
        else
443
        {
444
            // in case of tuple of arrays or just arrays (checked in transformArguments), try to redirect to sum/min/max-MappedArrays to implement old behavior
445
            auto nested_func_name = nested_function->getName();
446
            if (nested_func_name == "sum" || nested_func_name == "min" || nested_func_name == "max")
447
            {
448
                AggregateFunctionProperties out_properties;
449
                auto & aggr_func_factory = AggregateFunctionFactory::instance();
450
                auto action = NullsAction::EMPTY;
451
                return aggr_func_factory.get(nested_func_name + "MappedArrays", action, arguments, params, out_properties);
452
            }
453
            else
454
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregation '{}Map' is not implemented for mapped arrays",
455
                                 nested_func_name);
456
        }
457
    }
458
};
459

460
}
461

462
/// Registers the -Map combinator in the given combinator factory.
void registerAggregateFunctionCombinatorMap(AggregateFunctionCombinatorFactory & factory)
{
    const auto map_combinator = std::make_shared<AggregateFunctionCombinatorMap>();
    factory.registerCombinator(map_combinator);
}
466

467
}
468

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.