ClickHouse
467 строк · 17.2 Кб
1#include <unordered_map>2#include <AggregateFunctions/AggregateFunctionFactory.h>3#include <AggregateFunctions/IAggregateFunction.h>4#include <Columns/ColumnFixedString.h>5#include <Columns/ColumnMap.h>6#include <Columns/ColumnString.h>7#include <Columns/ColumnTuple.h>8#include <Columns/ColumnVector.h>9#include <DataTypes/DataTypeArray.h>10#include <DataTypes/DataTypeMap.h>11#include <DataTypes/DataTypeTuple.h>12#include <DataTypes/DataTypesNumber.h>13#include <Functions/FunctionHelpers.h>14#include <IO/ReadHelpers.h>15#include <IO/WriteHelpers.h>16#include <Common/Arena.h>17#include "AggregateFunctionCombinatorFactory.h"18
19
20namespace DB21{
22
namespace ErrorCodes
{
/// Error codes thrown from this file; they are only declared here (`extern`), the definitions live elsewhere.
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
28
29namespace
30{
31
32template <typename KeyType>33struct AggregateFunctionMapCombinatorData34{
35using SearchType = KeyType;36std::unordered_map<KeyType, AggregateDataPtr> merged_maps;37
38static void writeKey(KeyType key, WriteBuffer & buf) { writeBinaryLittleEndian(key, buf); }39static void readKey(KeyType & key, ReadBuffer & buf) { readBinaryLittleEndian(key, buf); }40};41
42template <>43struct AggregateFunctionMapCombinatorData<String>44{
45struct StringHash46{47using hash_type = std::hash<std::string_view>;48using is_transparent = void;49
50size_t operator()(std::string_view str) const { return hash_type{}(str); }51};52
53using SearchType = std::string_view;54std::unordered_map<String, AggregateDataPtr, StringHash, std::equal_to<>> merged_maps;55
56static void writeKey(String key, WriteBuffer & buf)57{58writeStringBinary(key, buf);59}60static void readKey(String & key, ReadBuffer & buf)61{62readStringBinary(key, buf);63}64};65
66/// Specialization for IPv6 - for historical reasons it should be stored as FixedString(16)
67template <>68struct AggregateFunctionMapCombinatorData<IPv6>69{
70struct IPv6Hash71{72using hash_type = std::hash<IPv6>;73using is_transparent = void;74
75size_t operator()(const IPv6 & ip) const { return hash_type{}(ip); }76};77
78using SearchType = IPv6;79std::unordered_map<IPv6, AggregateDataPtr, IPv6Hash, std::equal_to<>> merged_maps;80
81static void writeKey(const IPv6 & key, WriteBuffer & buf)82{83writeIPv6Binary(key, buf);84}85static void readKey(IPv6 & key, ReadBuffer & buf)86{87readIPv6Binary(key, buf);88}89};90
91template <typename KeyType>92class AggregateFunctionMap final93: public IAggregateFunctionDataHelper<AggregateFunctionMapCombinatorData<KeyType>, AggregateFunctionMap<KeyType>>94{
95private:96DataTypePtr key_type;97AggregateFunctionPtr nested_func;98
99using Data = AggregateFunctionMapCombinatorData<KeyType>;100using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionMap<KeyType>>;101
102public:103bool isState() const override104{105return nested_func->isState();106}107
108bool isVersioned() const override109{110return nested_func->isVersioned();111}112
113size_t getVersionFromRevision(size_t revision) const override114{115return nested_func->getVersionFromRevision(revision);116}117
118size_t getDefaultVersion() const override119{120return nested_func->getDefaultVersion();121}122
123AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types)124: Base(types, nested->getParameters(), std::make_shared<DataTypeMap>(DataTypes{getKeyType(types, nested), nested->getResultType()}))125, nested_func(nested)126{127key_type = getKeyType(types, nested_func);128}129
130String getName() const override { return nested_func->getName() + "Map"; }131
132static DataTypePtr getKeyType(const DataTypes & types, const AggregateFunctionPtr & nested)133{134if (types.size() != 1)135throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,136"Aggregate function {}Map requires one map argument, but {} found", nested->getName(), types.size());137
138const auto * map_type = checkAndGetDataType<DataTypeMap>(types[0].get());139if (!map_type)140throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,141"Aggregate function {}Map requires map as argument", nested->getName());142
143return map_type->getKeyType();144}145
146void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override147{148const auto & map_column = assert_cast<const ColumnMap &>(*columns[0]);149const auto & map_nested_tuple = map_column.getNestedData();150const IColumn::Offsets & map_array_offsets = map_column.getNestedColumn().getOffsets();151
152const size_t offset = map_array_offsets[row_num - 1];153const size_t size = (map_array_offsets[row_num] - offset);154
155const auto & key_column = map_nested_tuple.getColumn(0);156const auto & val_column = map_nested_tuple.getColumn(1);157
158auto & merged_maps = this->data(place).merged_maps;159
160for (size_t i = 0; i < size; ++i)161{162typename Data::SearchType key;163
164if constexpr (std::is_same_v<KeyType, String>)165{166StringRef key_ref;167if (key_type->getTypeId() == TypeIndex::FixedString)168key_ref = assert_cast<const ColumnFixedString &>(key_column).getDataAt(offset + i);169else if (key_type->getTypeId() == TypeIndex::IPv6)170key_ref = assert_cast<const ColumnIPv6 &>(key_column).getDataAt(offset + i);171else172key_ref = assert_cast<const ColumnString &>(key_column).getDataAt(offset + i);173
174key = key_ref.toView();175}176else177{178key = assert_cast<const ColumnVector<KeyType> &>(key_column).getData()[offset + i];179}180
181AggregateDataPtr nested_place;182auto it = merged_maps.find(key);183
184if (it == merged_maps.end())185{186// create a new place for each key187nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData());188nested_func->create(nested_place);189merged_maps.emplace(key, nested_place);190}191else192nested_place = it->second;193
194const IColumn * nested_columns[1] = {&val_column};195nested_func->add(nested_place, nested_columns, offset + i, arena);196}197}198
199void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override200{201auto & merged_maps = this->data(place).merged_maps;202const auto & rhs_maps = this->data(rhs).merged_maps;203
204for (const auto & elem : rhs_maps)205{206const auto & it = merged_maps.find(elem.first);207
208AggregateDataPtr nested_place;209if (it == merged_maps.end())210{211// elem.second cannot be copied since this it will be destroyed after merging,212// and lead to use-after-free.213nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData());214nested_func->create(nested_place);215merged_maps.emplace(elem.first, nested_place);216}217else218{219nested_place = it->second;220}221
222nested_func->merge(nested_place, elem.second, arena);223}224}225
226template <bool up_to_state>227void destroyImpl(AggregateDataPtr __restrict place) const noexcept228{229AggregateFunctionMapCombinatorData<KeyType> & state = Base::data(place);230
231for (const auto & [key, nested_place] : state.merged_maps)232{233if constexpr (up_to_state)234nested_func->destroyUpToState(nested_place);235else236nested_func->destroy(nested_place);237}238
239state.~Data();240}241
242void destroy(AggregateDataPtr __restrict place) const noexcept override243{244destroyImpl<false>(place);245}246
247bool hasTrivialDestructor() const override248{249return std::is_trivially_destructible_v<Data> && nested_func->hasTrivialDestructor();250}251
252void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override253{254destroyImpl<true>(place);255}256
257void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override258{259auto & merged_maps = this->data(place).merged_maps;260writeVarUInt(merged_maps.size(), buf);261
262for (const auto & elem : merged_maps)263{264this->data(place).writeKey(elem.first, buf);265nested_func->serialize(elem.second, buf);266}267}268
269void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override270{271auto & merged_maps = this->data(place).merged_maps;272UInt64 size;273
274readVarUInt(size, buf);275for (UInt64 i = 0; i < size; ++i)276{277KeyType key;278AggregateDataPtr nested_place;279
280this->data(place).readKey(key, buf);281nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData());282nested_func->create(nested_place);283merged_maps.emplace(key, nested_place);284nested_func->deserialize(nested_place, buf, std::nullopt, arena);285}286}287
288template <bool merge>289void insertResultIntoImpl(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const290{291auto & map_column = assert_cast<ColumnMap &>(to);292auto & nested_column = map_column.getNestedColumn();293auto & nested_data_column = map_column.getNestedData();294
295auto & key_column = nested_data_column.getColumn(0);296auto & val_column = nested_data_column.getColumn(1);297
298auto & merged_maps = this->data(place).merged_maps;299
300// sort the keys301std::vector<KeyType> keys;302keys.reserve(merged_maps.size());303for (auto & it : merged_maps)304{305keys.push_back(it.first);306}307::sort(keys.begin(), keys.end());308
309// insert using sorted keys to result column310for (auto & key : keys)311{312key_column.insert(key);313if constexpr (merge)314nested_func->insertMergeResultInto(merged_maps[key], val_column, arena);315else316nested_func->insertResultInto(merged_maps[key], val_column, arena);317}318
319IColumn::Offsets & res_offsets = nested_column.getOffsets();320res_offsets.push_back(val_column.size());321}322
323void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override324{325insertResultIntoImpl<false>(place, to, arena);326}327
328void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override329{330insertResultIntoImpl<true>(place, to, arena);331}332
333bool allocatesMemoryInArena() const override { return true; }334
335AggregateFunctionPtr getNestedFunction() const override { return nested_func; }336};337
338
339class AggregateFunctionCombinatorMap final : public IAggregateFunctionCombinator340{
341public:342String getName() const override { return "Map"; }343
344DataTypes transformArguments(const DataTypes & arguments) const override345{346if (arguments.empty())347throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,348"Incorrect number of arguments for aggregate function with {} suffix", getName());349
350const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());351if (map_type)352{353if (arguments.size() > 1)354throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "{} combinator takes only one map argument", getName());355
356return DataTypes({map_type->getValueType()});357}358
359// we need this part just to pass to redirection for mapped arrays360auto check_func = [](DataTypePtr t) { return t->getTypeId() == TypeIndex::Array; };361
362const auto * tup_type = checkAndGetDataType<DataTypeTuple>(arguments[0].get());363if (tup_type)364{365const auto & types = tup_type->getElements();366bool arrays_match = arguments.size() == 1 && types.size() >= 2 && std::all_of(types.begin(), types.end(), check_func);367if (arrays_match)368{369const auto * val_array_type = assert_cast<const DataTypeArray *>(types[1].get());370return DataTypes({val_array_type->getNestedType()});371}372}373else374{375bool arrays_match = arguments.size() >= 2 && std::all_of(arguments.begin(), arguments.end(), check_func);376if (arrays_match)377{378const auto * val_array_type = assert_cast<const DataTypeArray *>(arguments[1].get());379return DataTypes({val_array_type->getNestedType()});380}381}382
383throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} requires map as argument", getName());384}385
386AggregateFunctionPtr transformAggregateFunction(387const AggregateFunctionPtr & nested_function,388const AggregateFunctionProperties &,389const DataTypes & arguments,390const Array & params) const override391{392const auto * map_type = checkAndGetDataType<DataTypeMap>(arguments[0].get());393if (map_type)394{395const auto & key_type = map_type->getKeyType();396
397switch (key_type->getTypeId())398{399case TypeIndex::Enum8:400case TypeIndex::Int8:401return std::make_shared<AggregateFunctionMap<Int8>>(nested_function, arguments);402case TypeIndex::Enum16:403case TypeIndex::Int16:404return std::make_shared<AggregateFunctionMap<Int16>>(nested_function, arguments);405case TypeIndex::Int32:406return std::make_shared<AggregateFunctionMap<Int32>>(nested_function, arguments);407case TypeIndex::Int64:408return std::make_shared<AggregateFunctionMap<Int64>>(nested_function, arguments);409case TypeIndex::Int128:410return std::make_shared<AggregateFunctionMap<Int128>>(nested_function, arguments);411case TypeIndex::Int256:412return std::make_shared<AggregateFunctionMap<Int256>>(nested_function, arguments);413case TypeIndex::UInt8:414return std::make_shared<AggregateFunctionMap<UInt8>>(nested_function, arguments);415case TypeIndex::Date:416case TypeIndex::UInt16:417return std::make_shared<AggregateFunctionMap<UInt16>>(nested_function, arguments);418case TypeIndex::DateTime:419case TypeIndex::UInt32:420return std::make_shared<AggregateFunctionMap<UInt32>>(nested_function, arguments);421case TypeIndex::UInt64:422return std::make_shared<AggregateFunctionMap<UInt64>>(nested_function, arguments);423case TypeIndex::UInt128:424return std::make_shared<AggregateFunctionMap<UInt128>>(nested_function, arguments);425case TypeIndex::UInt256:426return std::make_shared<AggregateFunctionMap<UInt256>>(nested_function, arguments);427case TypeIndex::UUID:428return std::make_shared<AggregateFunctionMap<UUID>>(nested_function, arguments);429case TypeIndex::IPv4:430return std::make_shared<AggregateFunctionMap<IPv4>>(nested_function, arguments);431case TypeIndex::IPv6:432return std::make_shared<AggregateFunctionMap<IPv6>>(nested_function, arguments);433case TypeIndex::FixedString:434case TypeIndex::String:435return std::make_shared<AggregateFunctionMap<String>>(nested_function, arguments);436default:437throw 
Exception(438ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,439"Map key type {} is not is not supported by combinator {}", key_type->getName(), getName());440}441}442else443{444// in case of tuple of arrays or just arrays (checked in transformArguments), try to redirect to sum/min/max-MappedArrays to implement old behavior445auto nested_func_name = nested_function->getName();446if (nested_func_name == "sum" || nested_func_name == "min" || nested_func_name == "max")447{448AggregateFunctionProperties out_properties;449auto & aggr_func_factory = AggregateFunctionFactory::instance();450auto action = NullsAction::EMPTY;451return aggr_func_factory.get(nested_func_name + "MappedArrays", action, arguments, params, out_properties);452}453else454throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregation '{}Map' is not implemented for mapped arrays",455nested_func_name);456}457}458};459
460}
461
462void registerAggregateFunctionCombinatorMap(AggregateFunctionCombinatorFactory & factory)463{
464factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorMap>());465}
466
467}
468