ClickHouse
168 строк · 4.9 Кб
1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/IAggregateFunction.h>
3#include <AggregateFunctions/FactoryHelpers.h>
4#include <Columns/ColumnArray.h>
5#include <Common/assert_cast.h>
6#include <DataTypes/DataTypesNumber.h>
7#include <DataTypes/DataTypeArray.h>
8#include <IO/ReadHelpers.h>
9#include <IO/WriteHelpers.h>
10#include <base/range.h>
11
12#include <bitset>
13#include <unordered_set>
14
15
16namespace DB
17{
18
19struct Settings;
20
/// Error codes thrown by the code in this file. They are only declared here
/// (`extern`); their definitions are not part of this translation unit.
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
26
27namespace
28{
29
30struct AggregateFunctionRetentionData
31{
32static constexpr auto max_events = 32;
33
34using Events = std::bitset<max_events>;
35
36Events events;
37
38void add(UInt8 event)
39{
40events.set(event);
41}
42
43void merge(const AggregateFunctionRetentionData & other)
44{
45events |= other.events;
46}
47
48void serialize(WriteBuffer & buf) const
49{
50UInt32 event_value = static_cast<UInt32>(events.to_ulong());
51writeBinary(event_value, buf);
52}
53
54void deserialize(ReadBuffer & buf)
55{
56UInt32 event_value;
57readBinary(event_value, buf);
58events = event_value;
59}
60};
61
62/**
63* The max size of events is 32, that's enough for retention analytics
64*
65* Usage:
66* - retention(cond1, cond2, cond3, ....)
67* - returns [cond1_flag, cond1_flag && cond2_flag, cond1_flag && cond3_flag, ...]
68*/
69class AggregateFunctionRetention final
70: public IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>
71{
72private:
73UInt8 events_size;
74
75public:
76String getName() const override
77{
78return "retention";
79}
80
81explicit AggregateFunctionRetention(const DataTypes & arguments)
82: IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}, std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()))
83{
84for (const auto i : collections::range(0, arguments.size()))
85{
86const auto * cond_arg = arguments[i].get();
87if (!isUInt8(cond_arg))
88throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
89"Illegal type {} of argument {} of aggregate function {}, must be UInt8",
90cond_arg->getName(), i, getName());
91}
92
93events_size = static_cast<UInt8>(arguments.size());
94}
95
96bool allocatesMemoryInArena() const override { return false; }
97
98void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
99{
100for (const auto i : collections::range(0, events_size))
101{
102auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
103if (event)
104{
105data(place).add(i);
106}
107}
108}
109
110void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
111{
112data(place).merge(data(rhs));
113}
114
115void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
116{
117data(place).serialize(buf);
118}
119
120void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
121{
122data(place).deserialize(buf);
123}
124
125void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
126{
127auto & data_to = assert_cast<ColumnUInt8 &>(assert_cast<ColumnArray &>(to).getData()).getData();
128auto & offsets_to = assert_cast<ColumnArray &>(to).getOffsets();
129
130ColumnArray::Offset current_offset = data_to.size();
131data_to.resize(current_offset + events_size);
132
133const bool first_flag = data(place).events.test(0);
134data_to[current_offset] = first_flag;
135++current_offset;
136
137for (size_t i = 1; i < events_size; ++i)
138{
139data_to[current_offset] = (first_flag && data(place).events.test(i));
140++current_offset;
141}
142
143offsets_to.push_back(current_offset);
144}
145};
146
147
148AggregateFunctionPtr createAggregateFunctionRetention(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
149{
150assertNoParameters(name, params);
151
152if (arguments.size() < 2)
153throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Not enough event arguments for aggregate function {}", name);
154
155if (arguments.size() > AggregateFunctionRetentionData::max_events)
156throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Too many event arguments for aggregate function {}", name);
157
158return std::make_shared<AggregateFunctionRetention>(arguments);
159}
160
161}
162
163void registerAggregateFunctionRetention(AggregateFunctionFactory & factory)
164{
165factory.registerFunction("retention", createAggregateFunctionRetention);
166}
167
168}
169