ClickHouse

Форк
0
/
AggregateFunctionGroupBitmap.cpp 
285 строк · 11.1 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/FactoryHelpers.h>
3
#include <DataTypes/DataTypeAggregateFunction.h>
4

5
#include <AggregateFunctions/IAggregateFunction.h>
6
#include <Columns/ColumnAggregateFunction.h>
7
#include <Columns/ColumnVector.h>
8
#include <DataTypes/DataTypesNumber.h>
9
#include <Common/assert_cast.h>
10

11
// TODO: this header must be included last because of a broken roaring header. See the comment inside it.
12
#include <AggregateFunctions/AggregateFunctionGroupBitmapData.h>
13

14

15
namespace DB
16
{
17
struct Settings;
18

19
namespace ErrorCodes
20
{
21
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
22
}
23

24
namespace
25
{
26

27
/// Counts bitmap operation on numbers.
28
template <typename T, typename Data>
29
class AggregateFunctionBitmap final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>
30
{
31
public:
32
    explicit AggregateFunctionBitmap(const DataTypePtr & type)
33
        : IAggregateFunctionDataHelper<Data, AggregateFunctionBitmap<T, Data>>({type}, {}, createResultType())
34
    {
35
    }
36

37
    String getName() const override { return Data::name(); }
38

39
    static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
40

41
    bool allocatesMemoryInArena() const override { return false; }
42

43
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
44
    {
45
        this->data(place).roaring_bitmap_with_small_set.add(assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
46
    }
47

48
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
49
    {
50
        this->data(place).roaring_bitmap_with_small_set.merge(this->data(rhs).roaring_bitmap_with_small_set);
51
    }
52

53
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
54
    {
55
        this->data(place).roaring_bitmap_with_small_set.write(buf);
56
    }
57

58
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
59
    {
60
        this->data(place).roaring_bitmap_with_small_set.read(buf);
61
    }
62

63
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
64
    {
65
        assert_cast<ColumnVector<T> &>(to).getData().push_back(
66
            static_cast<T>(this->data(place).roaring_bitmap_with_small_set.size()));
67
    }
68
};
69

70

71
/// This aggregate function takes the states of AggregateFunctionBitmap as its argument.
72
template <typename T, typename Data, typename Policy>
73
class AggregateFunctionBitmapL2 final : public IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>
74
{
75
private:
76
    static constexpr size_t STATE_VERSION_1_MIN_REVISION = 54455;
77
public:
78
    explicit AggregateFunctionBitmapL2(const DataTypePtr & type)
79
        : IAggregateFunctionDataHelper<Data, AggregateFunctionBitmapL2<T, Data, Policy>>({type}, {}, createResultType())
80
    {
81
    }
82

83
    String getName() const override { return Policy::name; }
84

85
    static DataTypePtr createResultType() { return std::make_shared<DataTypeNumber<T>>(); }
86

87
    bool allocatesMemoryInArena() const override { return false; }
88

89
    DataTypePtr getStateType() const override
90
    {
91
        return this->argument_types.at(0);
92
    }
93

94
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
95
    {
96
        Data & data_lhs = this->data(place);
97
        const Data & data_rhs = this->data(assert_cast<const ColumnAggregateFunction &>(*columns[0]).getData()[row_num]);
98
        if (!data_lhs.init)
99
        {
100
            data_lhs.init = true;
101
            data_lhs.roaring_bitmap_with_small_set.merge(data_rhs.roaring_bitmap_with_small_set);
102
        }
103
        else
104
        {
105
            Policy::apply(data_lhs, data_rhs);
106
        }
107
    }
108

109
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
110
    {
111
        Data & data_lhs = this->data(place);
112
        const Data & data_rhs = this->data(rhs);
113

114
        if (!data_rhs.init)
115
            return;
116

117
        if (!data_lhs.init)
118
        {
119
            data_lhs.init = true;
120
            data_lhs.roaring_bitmap_with_small_set.merge(data_rhs.roaring_bitmap_with_small_set);
121
        }
122
        else
123
        {
124
            Policy::apply(data_lhs, data_rhs);
125
        }
126
    }
127

128
    bool isVersioned() const override { return true; }
129

130
    size_t getDefaultVersion() const override { return 1; }
131

132
    size_t getVersionFromRevision(size_t revision) const override
133
    {
134
        if (revision >= STATE_VERSION_1_MIN_REVISION)
135
            return 1;
136
        else
137
            return 0;
138
    }
139

140
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
141
    {
142
        if (!version)
143
            version = getDefaultVersion();
144

145
        if (*version >= 1)
146
            DB::writeBoolText(this->data(place).init, buf);
147

148
        this->data(place).roaring_bitmap_with_small_set.write(buf);
149
    }
150

151
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> version, Arena *) const override
152
    {
153
        if (!version)
154
            version = getDefaultVersion();
155

156
        if (*version >= 1)
157
            DB::readBoolText(this->data(place).init, buf);
158
        this->data(place).roaring_bitmap_with_small_set.read(buf);
159
    }
160

161
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
162
    {
163
        assert_cast<ColumnVector<T> &>(to).getData().push_back(
164
            static_cast<T>(this->data(place).roaring_bitmap_with_small_set.size()));
165
    }
166
};
167

168

169
/// Policy for groupBitmapAnd: combines two bitmap states via rb_and.
template <typename Data>
struct BitmapAndPolicy
{
    static constexpr auto name = "groupBitmapAnd";

    /// Applies rb_and of rhs's bitmap onto lhs's bitmap; rhs is not modified.
    static void apply(Data & lhs, const Data & rhs)
    {
        auto & target = lhs.roaring_bitmap_with_small_set;
        target.rb_and(rhs.roaring_bitmap_with_small_set);
    }
};
176

177
/// Policy for groupBitmapOr: combines two bitmap states via rb_or.
template <typename Data>
struct BitmapOrPolicy
{
    static constexpr auto name = "groupBitmapOr";

    /// Applies rb_or of rhs's bitmap onto lhs's bitmap; rhs is not modified.
    static void apply(Data & lhs, const Data & rhs)
    {
        auto & target = lhs.roaring_bitmap_with_small_set;
        target.rb_or(rhs.roaring_bitmap_with_small_set);
    }
};
184

185
/// Policy for groupBitmapXor: combines two bitmap states via rb_xor.
template <typename Data>
struct BitmapXorPolicy
{
    static constexpr auto name = "groupBitmapXor";

    /// Applies rb_xor of rhs's bitmap onto lhs's bitmap; rhs is not modified.
    static void apply(Data & lhs, const Data & rhs)
    {
        auto & target = lhs.roaring_bitmap_with_small_set;
        target.rb_xor(rhs.roaring_bitmap_with_small_set);
    }
};
192

193
template <typename T, typename Data>
194
using AggregateFunctionBitmapL2And = AggregateFunctionBitmapL2<T, Data, BitmapAndPolicy<Data>>;
195

196
template <typename T, typename Data>
197
using AggregateFunctionBitmapL2Or = AggregateFunctionBitmapL2<T, Data, BitmapOrPolicy<Data>>;
198

199
template <typename T, typename Data>
200
using AggregateFunctionBitmapL2Xor = AggregateFunctionBitmapL2<T, Data, BitmapXorPolicy<Data>>;
201

202

203
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> typename Data, typename... TArgs>
204
IAggregateFunction * createWithIntegerType(const IDataType & argument_type, TArgs &&... args)
205
{
206
    WhichDataType which(argument_type);
207
    if (which.idx == TypeIndex::UInt8) return new AggregateFunctionTemplate<UInt8, Data<UInt8>>(std::forward<TArgs>(args)...);
208
    if (which.idx == TypeIndex::UInt16) return new AggregateFunctionTemplate<UInt16, Data<UInt16>>(std::forward<TArgs>(args)...);
209
    if (which.idx == TypeIndex::UInt32) return new AggregateFunctionTemplate<UInt32, Data<UInt32>>(std::forward<TArgs>(args)...);
210
    if (which.idx == TypeIndex::UInt64) return new AggregateFunctionTemplate<UInt64, Data<UInt64>>(std::forward<TArgs>(args)...);
211
    if (which.idx == TypeIndex::Int8) return new AggregateFunctionTemplate<Int8, Data<Int8>>(std::forward<TArgs>(args)...);
212
    if (which.idx == TypeIndex::Int16) return new AggregateFunctionTemplate<Int16, Data<Int16>>(std::forward<TArgs>(args)...);
213
    if (which.idx == TypeIndex::Int32) return new AggregateFunctionTemplate<Int32, Data<Int32>>(std::forward<TArgs>(args)...);
214
    if (which.idx == TypeIndex::Int64) return new AggregateFunctionTemplate<Int64, Data<Int64>>(std::forward<TArgs>(args)...);
215
    return nullptr;
216
}
217

218
template <template <typename> typename Data>
219
AggregateFunctionPtr createAggregateFunctionBitmap(
220
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
221
{
222
    assertNoParameters(name, parameters);
223
    assertUnary(name, argument_types);
224

225
    if (!argument_types[0]->canBeUsedInBitOperations())
226
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
227
                        "The type {} of argument for aggregate function {} "
228
                        "is illegal, because it cannot be used in Bitmap operations",
229
                        argument_types[0]->getName(), name);
230

231
    AggregateFunctionPtr res(createWithIntegerType<AggregateFunctionBitmap, Data>(*argument_types[0], argument_types[0]));
232

233
    if (!res)
234
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
235
            argument_types[0]->getName(), name);
236

237
    return res;
238
}
239

240
// Additional aggregate functions to manipulate bitmaps.
241
template <template <typename, typename> typename AggregateFunctionTemplate>
242
AggregateFunctionPtr createAggregateFunctionBitmapL2(
243
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
244
{
245
    assertNoParameters(name, parameters);
246
    assertUnary(name, argument_types);
247

248
    DataTypePtr argument_type_ptr = argument_types[0];
249
    WhichDataType which(*argument_type_ptr);
250
    if (which.idx != TypeIndex::AggregateFunction)
251
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
252
            argument_types[0]->getName(), name);
253

254
    /// groupBitmap needs to know about the data type that was used to create bitmaps.
255
    /// We need to look inside the type of its argument to obtain it.
256
    const DataTypeAggregateFunction & datatype_aggfunc = dynamic_cast<const DataTypeAggregateFunction &>(*argument_type_ptr);
257
    AggregateFunctionPtr aggfunc = datatype_aggfunc.getFunction();
258

259
    if (aggfunc->getName() != AggregateFunctionGroupBitmapData<UInt8>::name())
260
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
261
            argument_types[0]->getName(), name);
262

263
    DataTypePtr nested_argument_type_ptr = aggfunc->getArgumentTypes()[0];
264

265
    AggregateFunctionPtr res(createWithIntegerType<AggregateFunctionTemplate, AggregateFunctionGroupBitmapData>(
266
        *nested_argument_type_ptr, argument_type_ptr));
267

268
    if (!res)
269
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}",
270
            argument_types[0]->getName(), name);
271

272
    return res;
273
}
274

275
}
276

277
/// Registers groupBitmap and its level-2 combinators with the aggregate function factory.
void registerAggregateFunctionsBitmap(AggregateFunctionFactory & factory)
{
    /// Builds a bitmap from integer values.
    factory.registerFunction("groupBitmap", &createAggregateFunctionBitmap<AggregateFunctionGroupBitmapData>);

    /// Combine groupBitmap states with AND / OR / XOR.
    factory.registerFunction("groupBitmapAnd", &createAggregateFunctionBitmapL2<AggregateFunctionBitmapL2And>);
    factory.registerFunction("groupBitmapOr", &createAggregateFunctionBitmapL2<AggregateFunctionBitmapL2Or>);
    factory.registerFunction("groupBitmapXor", &createAggregateFunctionBitmapL2<AggregateFunctionBitmapL2Xor>);
}
284

285
}
286

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.