ClickHouse
345 строк · 11.9 Кб
1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/Helpers.h>
3#include <Core/Settings.h>
4#include <DataTypes/DataTypeDate.h>
5#include <DataTypes/DataTypeDateTime.h>
6
7#include <unordered_set>
8#include <Columns/ColumnsNumber.h>
9#include <DataTypes/DataTypesNumber.h>
10#include <IO/ReadHelpers.h>
11#include <IO/WriteHelpers.h>
12#include <Common/assert_cast.h>
13
14
15namespace DB
16{
/// Forward declaration: this file only passes `const Settings *` around.
struct Settings;

/// Error codes thrown from this file. They are declared `extern` here; the definitions live in another translation unit.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int TOO_LARGE_ARRAY_SIZE;
}
26
27namespace
28{
29
/// Upper bound on the number of event conditions accepted by windowFunnel.
/// The condition count is kept in a UInt8 (`events_size`), so this must not exceed 255.
constexpr size_t max_events{32};
31
32template <typename T>
33struct AggregateFunctionWindowFunnelData
34{
35using TimestampEvent = std::pair<T, UInt8>;
36using TimestampEvents = PODArrayWithStackMemory<TimestampEvent, 64>;
37
38bool sorted = true;
39TimestampEvents events_list;
40
41size_t size() const
42{
43return events_list.size();
44}
45
46void add(T timestamp, UInt8 event)
47{
48/// Since most events should have already been sorted by timestamp.
49if (sorted && events_list.size() > 0)
50{
51if (events_list.back().first == timestamp)
52sorted = events_list.back().second <= event;
53else
54sorted = events_list.back().first <= timestamp;
55}
56events_list.emplace_back(timestamp, event);
57}
58
59void merge(const AggregateFunctionWindowFunnelData & other)
60{
61if (other.events_list.empty())
62return;
63
64const auto size = events_list.size();
65
66events_list.insert(std::begin(other.events_list), std::end(other.events_list));
67
68/// either sort whole container or do so partially merging ranges afterwards
69if (!sorted && !other.sorted)
70std::stable_sort(std::begin(events_list), std::end(events_list));
71else
72{
73const auto begin = std::begin(events_list);
74const auto middle = std::next(begin, size);
75const auto end = std::end(events_list);
76
77if (!sorted)
78std::stable_sort(begin, middle);
79
80if (!other.sorted)
81std::stable_sort(middle, end);
82
83std::inplace_merge(begin, middle, end);
84}
85
86sorted = true;
87}
88
89void sort()
90{
91if (!sorted)
92{
93std::stable_sort(std::begin(events_list), std::end(events_list));
94sorted = true;
95}
96}
97
98void serialize(WriteBuffer & buf) const
99{
100writeBinary(sorted, buf);
101writeBinary(events_list.size(), buf);
102
103for (const auto & events : events_list)
104{
105writeBinary(events.first, buf);
106writeBinary(events.second, buf);
107}
108}
109
110void deserialize(ReadBuffer & buf)
111{
112readBinary(sorted, buf);
113
114size_t size;
115readBinary(size, buf);
116
117if (size > 100'000'000) /// The constant is arbitrary
118throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large size of the state of windowFunnel");
119
120events_list.clear();
121events_list.reserve(size);
122
123T timestamp;
124UInt8 event;
125
126for (size_t i = 0; i < size; ++i)
127{
128readBinary(timestamp, buf);
129readBinary(event, buf);
130events_list.emplace_back(timestamp, event);
131}
132}
133};
134
135/** Calculates the max event level in a sliding window.
136* The max size of events is 32, that's enough for funnel analytics
137*
138* Usage:
139* - windowFunnel(window)(timestamp, cond1, cond2, cond3, ....)
140*/
141template <typename T, typename Data>
142class AggregateFunctionWindowFunnel final
143: public IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>
144{
145private:
146UInt64 window;
147UInt8 events_size;
148/// When the 'strict_deduplication' is set, it applies conditions only for the not repeating values.
149bool strict_deduplication;
150
151/// When the 'strict_order' is set, it doesn't allow interventions of other events.
152/// In the case of 'A->B->D->C', it stops finding 'A->B->C' at the 'D' and the max event level is 2.
153bool strict_order;
154
155/// Applies conditions only to events with strictly increasing timestamps
156bool strict_increase;
157
158/// Loop through the entire events_list, update the event timestamp value
159/// The level path must be 1---2---3---...---check_events_size, find the max event level that satisfied the path in the sliding window.
160/// If found, returns the max event level, else return 0.
161/// The algorithm works in O(n) time, but the overall function works in O(n * log(n)) due to sorting.
162UInt8 getEventLevel(Data & data) const
163{
164if (data.size() == 0)
165return 0;
166if (!strict_order && events_size == 1)
167return 1;
168
169data.sort();
170
171/// events_timestamp stores the timestamp of the first and previous i-th level event happen within time window
172std::vector<std::optional<std::pair<UInt64, UInt64>>> events_timestamp(events_size);
173bool first_event = false;
174for (size_t i = 0; i < data.events_list.size(); ++i)
175{
176const T & timestamp = data.events_list[i].first;
177const auto & event_idx = data.events_list[i].second - 1;
178if (strict_order && event_idx == -1)
179{
180if (first_event)
181break;
182else
183continue;
184}
185else if (event_idx == 0)
186{
187events_timestamp[0] = std::make_pair(timestamp, timestamp);
188first_event = true;
189}
190else if (strict_deduplication && events_timestamp[event_idx].has_value())
191{
192return data.events_list[i - 1].second;
193}
194else if (strict_order && first_event && !events_timestamp[event_idx - 1].has_value())
195{
196for (size_t event = 0; event < events_timestamp.size(); ++event)
197{
198if (!events_timestamp[event].has_value())
199return event;
200}
201}
202else if (events_timestamp[event_idx - 1].has_value())
203{
204auto first_timestamp = events_timestamp[event_idx - 1]->first;
205bool time_matched = timestamp <= first_timestamp + window;
206if (strict_increase)
207time_matched = time_matched && events_timestamp[event_idx - 1]->second < timestamp;
208if (time_matched)
209{
210events_timestamp[event_idx] = std::make_pair(first_timestamp, timestamp);
211if (event_idx + 1 == events_size)
212return events_size;
213}
214}
215}
216
217for (size_t event = events_timestamp.size(); event > 0; --event)
218{
219if (events_timestamp[event - 1].has_value())
220return event;
221}
222return 0;
223}
224
225public:
226String getName() const override
227{
228return "windowFunnel";
229}
230
231AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
232: IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params, std::make_shared<DataTypeUInt8>())
233{
234events_size = arguments.size() - 1;
235window = params.at(0).safeGet<UInt64>();
236
237strict_deduplication = false;
238strict_order = false;
239strict_increase = false;
240for (size_t i = 1; i < params.size(); ++i)
241{
242String option = params.at(i).safeGet<String>();
243if (option == "strict_deduplication")
244strict_deduplication = true;
245else if (option == "strict_order")
246strict_order = true;
247else if (option == "strict_increase")
248strict_increase = true;
249else if (option == "strict")
250throw Exception(ErrorCodes::BAD_ARGUMENTS, "strict is replaced with strict_deduplication in Aggregate function {}", getName());
251else
252throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} doesn't support a parameter: {}", getName(), option);
253}
254}
255
256bool allocatesMemoryInArena() const override { return false; }
257
258void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
259{
260bool has_event = false;
261const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
262/// reverse iteration and stable sorting are needed for events that are qualified by more than one condition.
263for (auto i = events_size; i > 0; --i)
264{
265auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
266if (event)
267{
268this->data(place).add(timestamp, i);
269has_event = true;
270}
271}
272
273if (strict_order && !has_event)
274this->data(place).add(timestamp, 0);
275}
276
277void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
278{
279this->data(place).merge(this->data(rhs));
280}
281
282void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
283{
284this->data(place).serialize(buf);
285}
286
287void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
288{
289this->data(place).deserialize(buf);
290}
291
292void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
293{
294assert_cast<ColumnUInt8 &>(to).getData().push_back(getEventLevel(this->data(place)));
295}
296};
297
298
299template <template <typename> class Data>
300AggregateFunctionPtr
301createAggregateFunctionWindowFunnel(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
302{
303if (params.empty())
304throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
305"Aggregate function {} requires at least one parameter: <window>, [option, [option, ...]]",
306name);
307
308if (arguments.size() < 2)
309throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
310"Aggregate function {} requires one timestamp argument and at least one event condition.", name);
311
312if (arguments.size() > max_events + 1)
313throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Too many event arguments for aggregate function {}", name);
314
315for (const auto i : collections::range(1, arguments.size()))
316{
317const auto * cond_arg = arguments[i].get();
318if (!isUInt8(cond_arg))
319throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
320"Illegal type {} of argument {} of aggregate function {}, must be UInt8",
321cond_arg->getName(), toString(i + 1), name);
322}
323
324AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionWindowFunnel, Data>(*arguments[0], arguments, params));
325WhichDataType which(arguments.front().get());
326if (res)
327return res;
328else if (which.isDate())
329return std::make_shared<AggregateFunctionWindowFunnel<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(arguments, params);
330else if (which.isDateTime())
331return std::make_shared<AggregateFunctionWindowFunnel<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(arguments, params);
332
333throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
334"Illegal type {} of first argument of aggregate function {}, must "
335"be Unsigned Number, Date, DateTime", arguments.front().get()->getName(), name);
336}
337
338}
339
340void registerAggregateFunctionWindowFunnel(AggregateFunctionFactory & factory)
341{
342factory.registerFunction("windowFunnel", createAggregateFunctionWindowFunnel<AggregateFunctionWindowFunnelData>);
343}
344
345}
346