ClickHouse

Форк
0
/
AggregateFunctionCombinatorsArgMinArgMax.cpp 
212 строк · 7.4 Кб
1
#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
2
#include <AggregateFunctions/SingleValueData.h>
3

4
namespace DB
5
{
6

7
namespace ErrorCodes
{
    /// Declared `extern`: only the codes thrown from this file are named here; their definitions live in another translation unit.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
12

13
namespace
14
{
15

16
struct AggregateFunctionCombinatorArgMinArgMaxData
17
{
18
private:
19
    SingleValueDataBaseMemoryBlock v_data;
20

21
public:
22
    explicit AggregateFunctionCombinatorArgMinArgMaxData(TypeIndex value_type) { generateSingleValueFromTypeIndex(value_type, v_data); }
23

24
    ~AggregateFunctionCombinatorArgMinArgMaxData() { data().~SingleValueDataBase(); }
25

26
    SingleValueDataBase & data() { return v_data.get(); }
27
    const SingleValueDataBase & data() const { return v_data.get(); }
28
};
29

30
template <bool isMin>
31
class AggregateFunctionCombinatorArgMinArgMax final : public IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>
32
{
33
    using Key = AggregateFunctionCombinatorArgMinArgMaxData;
34

35
private:
36
    AggregateFunctionPtr nested_function;
37
    SerializationPtr serialization;
38
    const size_t key_col;
39
    const size_t key_offset;
40
    const TypeIndex key_type_index;
41

42
    AggregateFunctionCombinatorArgMinArgMaxData & data(AggregateDataPtr __restrict place) const /// NOLINT
43
    {
44
        return *reinterpret_cast<Key *>(place + key_offset);
45
    }
46
    const AggregateFunctionCombinatorArgMinArgMaxData & data(ConstAggregateDataPtr __restrict place) const
47
    {
48
        return *reinterpret_cast<const Key *>(place + key_offset);
49
    }
50

51
public:
52
    AggregateFunctionCombinatorArgMinArgMax(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
53
        : IAggregateFunctionHelper<AggregateFunctionCombinatorArgMinArgMax<isMin>>{arguments, params, nested_function_->getResultType()}
54
        , nested_function{nested_function_}
55
        , serialization(arguments.back()->getDefaultSerialization())
56
        , key_col{arguments.size() - 1}
57
        , key_offset{((nested_function->sizeOfData() + alignof(Key) - 1) / alignof(Key)) * alignof(Key)}
58
        , key_type_index(WhichDataType(arguments[key_col]).idx)
59
    {
60
        if (!arguments[key_col]->isComparable())
61
            throw Exception(
62
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
63
                "Illegal type {} for combinator {} because the values of that data type are not comparable",
64
                arguments[key_col]->getName(),
65
                getName());
66
    }
67

68
    String getName() const override
69
    {
70
        if constexpr (isMin)
71
            return "ArgMin";
72
        else
73
            return "ArgMax";
74
    }
75

76
    bool isState() const override { return nested_function->isState(); }
77

78
    bool isVersioned() const override { return nested_function->isVersioned(); }
79

80
    size_t getVersionFromRevision(size_t revision) const override { return nested_function->getVersionFromRevision(revision); }
81

82
    size_t getDefaultVersion() const override { return nested_function->getDefaultVersion(); }
83

84
    bool allocatesMemoryInArena() const override
85
    {
86
        return nested_function->allocatesMemoryInArena() || singleValueTypeAllocatesMemoryInArena(key_type_index);
87
    }
88

89
    bool hasTrivialDestructor() const override
90
    {
91
        return nested_function->hasTrivialDestructor() && /*false*/ std::is_trivially_destructible_v<SingleValueDataBase>;
92
    }
93

94
    size_t sizeOfData() const override { return key_offset + sizeof(Key); }
95

96
    size_t alignOfData() const override { return std::max(nested_function->alignOfData(), alignof(SingleValueDataBaseMemoryBlock)); }
97

98
    void create(AggregateDataPtr __restrict place) const override
99
    {
100
        nested_function->create(place);
101
        new (place + key_offset) Key(key_type_index);
102
    }
103

104
    void destroy(AggregateDataPtr __restrict place) const noexcept override
105
    {
106
        data(place).~Key();
107
        nested_function->destroy(place);
108
    }
109

110
    void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override
111
    {
112
        data(place).~Key();
113
        nested_function->destroyUpToState(place);
114
    }
115

116
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
117
    {
118
        if ((isMin && data(place).data().setIfSmaller(*columns[key_col], row_num, arena))
119
            || (!isMin && data(place).data().setIfGreater(*columns[key_col], row_num, arena)))
120
        {
121
            nested_function->destroy(place);
122
            nested_function->create(place);
123
            nested_function->add(place, columns, row_num, arena);
124
        }
125
        else if (data(place).data().isEqualTo(*columns[key_col], row_num))
126
        {
127
            nested_function->add(place, columns, row_num, arena);
128
        }
129
    }
130

131
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
132
    {
133
        if ((isMin && data(place).data().setIfSmaller(data(rhs).data(), arena))
134
            || (!isMin && data(place).data().setIfGreater(data(rhs).data(), arena)))
135
        {
136
            nested_function->destroy(place);
137
            nested_function->create(place);
138
            nested_function->merge(place, rhs, arena);
139
        }
140
        else if (data(place).data().isEqualTo(data(rhs).data()))
141
        {
142
            nested_function->merge(place, rhs, arena);
143
        }
144
    }
145

146
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
147
    {
148
        nested_function->serialize(place, buf, version);
149
        data(place).data().write(buf, *serialization);
150
    }
151

152
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override
153
    {
154
        nested_function->deserialize(place, buf, version, arena);
155
        data(place).data().read(buf, *serialization, arena);
156
    }
157

158
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
159
    {
160
        nested_function->insertResultInto(place, to, arena);
161
    }
162

163
    void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
164
    {
165
        nested_function->insertMergeResultInto(place, to, arena);
166
    }
167

168
    AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
169
};
170

171
template <bool isMin>
172
class CombinatorArgMinArgMax final : public IAggregateFunctionCombinator
173
{
174
public:
175
    String getName() const override
176
    {
177
        if constexpr (isMin)
178
            return "ArgMin";
179
        else
180
            return "ArgMax";
181
    }
182

183
    DataTypes transformArguments(const DataTypes & arguments) const override
184
    {
185
        if (arguments.empty())
186
            throw Exception(
187
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
188
                "Incorrect number of arguments for aggregate function with {} suffix",
189
                getName());
190

191
        return DataTypes(arguments.begin(), arguments.end() - 1);
192
    }
193

194
    AggregateFunctionPtr transformAggregateFunction(
195
        const AggregateFunctionPtr & nested_function,
196
        const AggregateFunctionProperties &,
197
        const DataTypes & arguments,
198
        const Array & params) const override
199
    {
200
        return std::make_shared<AggregateFunctionCombinatorArgMinArgMax<isMin>>(nested_function, arguments, params);
201
    }
202
};
203

204
}
205

206
/// Registers the -ArgMin and -ArgMax combinators with `factory`, in that order.
void registerAggregateFunctionCombinatorsArgMinArgMax(AggregateFunctionCombinatorFactory & factory)
{
    using ArgMinCombinator = CombinatorArgMinArgMax</* isMin = */ true>;
    using ArgMaxCombinator = CombinatorArgMinArgMax</* isMin = */ false>;

    factory.registerCombinator(std::make_shared<ArgMinCombinator>());
    factory.registerCombinator(std::make_shared<ArgMaxCombinator>());
}
211

212
}
213

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.