ClickHouse
212 строк · 7.4 Кб
1#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
2#include <AggregateFunctions/SingleValueData.h>
3
4namespace DB
5{
6
7namespace ErrorCodes
8{
9extern const int ILLEGAL_TYPE_OF_ARGUMENT;
10extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
11}
12
13namespace
14{
15
16struct AggregateFunctionCombinatorArgMinArgMaxData
17{
18private:
19SingleValueDataBaseMemoryBlock v_data;
20
21public:
22explicit AggregateFunctionCombinatorArgMinArgMaxData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
23
24~AggregateFunctionCombinatorArgMinArgMaxData() { data().~SingleValueDataBase(); }
25
26SingleValueDataBase & data() { return v_data.get(); }
27const SingleValueDataBase & data() const { return v_data.get(); }
28};
29
30template <bool isMin>
31class AggregateFunctionCombinatorArgMinArgMax final : public IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>
32{
33using Key = AggregateFunctionCombinatorArgMinArgMaxData;
34
35private:
36AggregateFunctionPtr nested_function;
37SerializationPtr serialization;
38const size_t key_col;
39const size_t key_offset;
40const TypeIndex key_type_index;
41
42AggregateFunctionCombinatorArgMinArgMaxData & data(AggregateDataPtr __restrict place) const /// NOLINT
43{
44return *reinterpret_cast<Key *>(place + key_offset);
45}
46const AggregateFunctionCombinatorArgMinArgMaxData & data(ConstAggregateDataPtr __restrict place) const
47{
48return *reinterpret_cast<const Key *>(place + key_offset);
49}
50
51public:
52AggregateFunctionCombinatorArgMinArgMax(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
53: IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>{arguments, params, nested_function_->getResultType()}
54, nested_function{nested_function_}
55, serialization(arguments.back()->getDefaultSerialization())
56, key_col{arguments.size() - 1}
57, key_offset{((nested_function->sizeOfData() + alignof(Key) - 1) / alignof(Key)) * alignof(Key)}
58, key_type_index(WhichDataType(arguments[key_col]).idx)
59{
60if (!arguments[key_col]->isComparable())
61throw Exception(
62ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
63"Illegal type {} for combinator {} because the values of that data type are not comparable",
64arguments[key_col]->getName(),
65getName());
66}
67
68String getName() const override
69{
70if constexpr (isMin)
71return "ArgMin";
72else
73return "ArgMax";
74}
75
76bool isState() const override { return nested_function->isState(); }
77
78bool isVersioned() const override { return nested_function->isVersioned(); }
79
80size_t getVersionFromRevision(size_t revision) const override { return nested_function->getVersionFromRevision(revision); }
81
82size_t getDefaultVersion() const override { return nested_function->getDefaultVersion(); }
83
84bool allocatesMemoryInArena() const override
85{
86return nested_function->allocatesMemoryInArena() || singleValueTypeAllocatesMemoryInArena(key_type_index);
87}
88
89bool hasTrivialDestructor() const override
90{
91return nested_function->hasTrivialDestructor() && /*false*/ std::is_trivially_destructible_v<SingleValueDataBase>;
92}
93
94size_t sizeOfData() const override { return key_offset + sizeof(Key); }
95
96size_t alignOfData() const override { return std::max(nested_function->alignOfData(), alignof(SingleValueDataBaseMemoryBlock)); }
97
98void create(AggregateDataPtr __restrict place) const override
99{
100nested_function->create(place);
101new (place + key_offset) Key(key_type_index);
102}
103
104void destroy(AggregateDataPtr __restrict place) const noexcept override
105{
106data(place).~Key();
107nested_function->destroy(place);
108}
109
110void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override
111{
112data(place).~Key();
113nested_function->destroyUpToState(place);
114}
115
116void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
117{
118if ((isMin && data(place).data().setIfSmaller(*columns[key_col], row_num, arena))
119|| (!isMin && data(place).data().setIfGreater(*columns[key_col], row_num, arena)))
120{
121nested_function->destroy(place);
122nested_function->create(place);
123nested_function->add(place, columns, row_num, arena);
124}
125else if (data(place).data().isEqualTo(*columns[key_col], row_num))
126{
127nested_function->add(place, columns, row_num, arena);
128}
129}
130
131void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
132{
133if ((isMin && data(place).data().setIfSmaller(data(rhs).data(), arena))
134|| (!isMin && data(place).data().setIfGreater(data(rhs).data(), arena)))
135{
136nested_function->destroy(place);
137nested_function->create(place);
138nested_function->merge(place, rhs, arena);
139}
140else if (data(place).data().isEqualTo(data(rhs).data()))
141{
142nested_function->merge(place, rhs, arena);
143}
144}
145
146void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
147{
148nested_function->serialize(place, buf, version);
149data(place).data().write(buf, *serialization);
150}
151
152void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
153{
154nested_function->deserialize(place, buf, version, arena);
155data(place).data().read(buf, *serialization, arena);
156}
157
158void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
159{
160nested_function->insertResultInto(place, to, arena);
161}
162
163void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
164{
165nested_function->insertMergeResultInto(place, to, arena);
166}
167
168AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
169};
170
171template <bool isMin>
172class CombinatorArgMinArgMax final : public IAggregateFunctionCombinator
173{
174public:
175String getName() const override
176{
177if constexpr (isMin)
178return "ArgMin";
179else
180return "ArgMax";
181}
182
183DataTypes transformArguments(const DataTypes & arguments) const override
184{
185if (arguments.empty())
186throw Exception(
187ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
188"Incorrect number of arguments for aggregate function with {} suffix",
189getName());
190
191return DataTypes(arguments.begin(), arguments.end() - 1);
192}
193
194AggregateFunctionPtr transformAggregateFunction(
195const AggregateFunctionPtr & nested_function,
196const AggregateFunctionProperties &,
197const DataTypes & arguments,
198const Array & params) const override
199{
200return std::make_shared<AggregateFunctionCombinatorArgMinArgMax<isMin>>(nested_function, arguments, params);
201}
202};
203
204}
205
206void registerAggregateFunctionCombinatorsArgMinArgMax(AggregateFunctionCombinatorFactory & factory)
207{
208factory.registerCombinator(std::make_shared<CombinatorArgMinArgMax<true>>());
209factory.registerCombinator(std::make_shared<CombinatorArgMinArgMax<false>>());
210}
211
212}
213