ClickHouse

Форк
0
/
AggregateFunctionWindowFunnel.cpp 
345 строк · 11.9 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/Helpers.h>
3
#include <Core/Settings.h>
4
#include <DataTypes/DataTypeDate.h>
5
#include <DataTypes/DataTypeDateTime.h>
6

7
#include <unordered_set>
8
#include <Columns/ColumnsNumber.h>
9
#include <DataTypes/DataTypesNumber.h>
10
#include <IO/ReadHelpers.h>
11
#include <IO/WriteHelpers.h>
12
#include <Common/assert_cast.h>
13

14

15
namespace DB
16
{
17
struct Settings;
18

19
namespace ErrorCodes
20
{
21
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
22
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
23
    extern const int TOO_LARGE_ARRAY_SIZE;
24
    extern const int BAD_ARGUMENTS;
25
}
26

27
namespace
28
{
29

30
/// Upper bound on the number of event conditions accepted by windowFunnel
/// (checked against the argument count when the function is created).
constexpr size_t max_events = 32;
31

32
template <typename T>
33
struct AggregateFunctionWindowFunnelData
34
{
35
    using TimestampEvent = std::pair<T, UInt8>;
36
    using TimestampEvents = PODArrayWithStackMemory<TimestampEvent, 64>;
37

38
    bool sorted = true;
39
    TimestampEvents events_list;
40

41
    size_t size() const
42
    {
43
        return events_list.size();
44
    }
45

46
    void add(T timestamp, UInt8 event)
47
    {
48
        /// Since most events should have already been sorted by timestamp.
49
        if (sorted && events_list.size() > 0)
50
        {
51
            if (events_list.back().first == timestamp)
52
                sorted = events_list.back().second <= event;
53
            else
54
                sorted = events_list.back().first <= timestamp;
55
        }
56
        events_list.emplace_back(timestamp, event);
57
    }
58

59
    void merge(const AggregateFunctionWindowFunnelData & other)
60
    {
61
        if (other.events_list.empty())
62
            return;
63

64
        const auto size = events_list.size();
65

66
        events_list.insert(std::begin(other.events_list), std::end(other.events_list));
67

68
        /// either sort whole container or do so partially merging ranges afterwards
69
        if (!sorted && !other.sorted)
70
            std::stable_sort(std::begin(events_list), std::end(events_list));
71
        else
72
        {
73
            const auto begin = std::begin(events_list);
74
            const auto middle = std::next(begin, size);
75
            const auto end = std::end(events_list);
76

77
            if (!sorted)
78
                std::stable_sort(begin, middle);
79

80
            if (!other.sorted)
81
                std::stable_sort(middle, end);
82

83
            std::inplace_merge(begin, middle, end);
84
        }
85

86
        sorted = true;
87
    }
88

89
    void sort()
90
    {
91
        if (!sorted)
92
        {
93
            std::stable_sort(std::begin(events_list), std::end(events_list));
94
            sorted = true;
95
        }
96
    }
97

98
    void serialize(WriteBuffer & buf) const
99
    {
100
        writeBinary(sorted, buf);
101
        writeBinary(events_list.size(), buf);
102

103
        for (const auto & events : events_list)
104
        {
105
            writeBinary(events.first, buf);
106
            writeBinary(events.second, buf);
107
        }
108
    }
109

110
    void deserialize(ReadBuffer & buf)
111
    {
112
        readBinary(sorted, buf);
113

114
        size_t size;
115
        readBinary(size, buf);
116

117
        if (size > 100'000'000) /// The constant is arbitrary
118
            throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large size of the state of windowFunnel");
119

120
        events_list.clear();
121
        events_list.reserve(size);
122

123
        T timestamp;
124
        UInt8 event;
125

126
        for (size_t i = 0; i < size; ++i)
127
        {
128
            readBinary(timestamp, buf);
129
            readBinary(event, buf);
130
            events_list.emplace_back(timestamp, event);
131
        }
132
    }
133
};
134

135
/** Calculates the max event level in a sliding window.
136
  * The max size of events is 32, that's enough for funnel analytics
137
  *
138
  * Usage:
139
  * - windowFunnel(window)(timestamp, cond1, cond2, cond3, ....)
140
  */
141
template <typename T, typename Data>
142
class AggregateFunctionWindowFunnel final
143
    : public IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>
144
{
145
private:
146
    UInt64 window;
147
    UInt8 events_size;
148
    /// When the 'strict_deduplication' is set, it applies conditions only for the not repeating values.
149
    bool strict_deduplication;
150

151
    /// When the 'strict_order' is set, it doesn't allow interventions of other events.
152
    /// In the case of 'A->B->D->C', it stops finding 'A->B->C' at the 'D' and the max event level is 2.
153
    bool strict_order;
154

155
    /// Applies conditions only to events with strictly increasing timestamps
156
    bool strict_increase;
157

158
    /// Loop through the entire events_list, update the event timestamp value
159
    /// The level path must be 1---2---3---...---check_events_size, find the max event level that satisfied the path in the sliding window.
160
    /// If found, returns the max event level, else return 0.
161
    /// The algorithm works in O(n) time, but the overall function works in O(n * log(n)) due to sorting.
162
    UInt8 getEventLevel(Data & data) const
163
    {
164
        if (data.size() == 0)
165
            return 0;
166
        if (!strict_order && events_size == 1)
167
            return 1;
168

169
        data.sort();
170

171
        /// events_timestamp stores the timestamp of the first and previous i-th level event happen within time window
172
        std::vector<std::optional<std::pair<UInt64, UInt64>>> events_timestamp(events_size);
173
        bool first_event = false;
174
        for (size_t i = 0; i < data.events_list.size(); ++i)
175
        {
176
            const T & timestamp = data.events_list[i].first;
177
            const auto & event_idx = data.events_list[i].second - 1;
178
            if (strict_order && event_idx == -1)
179
            {
180
                if (first_event)
181
                    break;
182
                else
183
                    continue;
184
            }
185
            else if (event_idx == 0)
186
            {
187
                events_timestamp[0] = std::make_pair(timestamp, timestamp);
188
                first_event = true;
189
            }
190
            else if (strict_deduplication && events_timestamp[event_idx].has_value())
191
            {
192
                return data.events_list[i - 1].second;
193
            }
194
            else if (strict_order && first_event && !events_timestamp[event_idx - 1].has_value())
195
            {
196
                for (size_t event = 0; event < events_timestamp.size(); ++event)
197
                {
198
                    if (!events_timestamp[event].has_value())
199
                        return event;
200
                }
201
            }
202
            else if (events_timestamp[event_idx - 1].has_value())
203
            {
204
                auto first_timestamp = events_timestamp[event_idx - 1]->first;
205
                bool time_matched = timestamp <= first_timestamp + window;
206
                if (strict_increase)
207
                    time_matched = time_matched && events_timestamp[event_idx - 1]->second < timestamp;
208
                if (time_matched)
209
                {
210
                    events_timestamp[event_idx] = std::make_pair(first_timestamp, timestamp);
211
                    if (event_idx + 1 == events_size)
212
                        return events_size;
213
                }
214
            }
215
        }
216

217
        for (size_t event = events_timestamp.size(); event > 0; --event)
218
        {
219
            if (events_timestamp[event - 1].has_value())
220
                return event;
221
        }
222
        return 0;
223
    }
224

225
public:
226
    String getName() const override
227
    {
228
        return "windowFunnel";
229
    }
230

231
    AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
232
        : IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params, std::make_shared<DataTypeUInt8>())
233
    {
234
        events_size = arguments.size() - 1;
235
        window = params.at(0).safeGet<UInt64>();
236

237
        strict_deduplication = false;
238
        strict_order = false;
239
        strict_increase = false;
240
        for (size_t i = 1; i < params.size(); ++i)
241
        {
242
            String option = params.at(i).safeGet<String>();
243
            if (option == "strict_deduplication")
244
                strict_deduplication = true;
245
            else if (option == "strict_order")
246
                strict_order = true;
247
            else if (option == "strict_increase")
248
                strict_increase = true;
249
            else if (option == "strict")
250
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "strict is replaced with strict_deduplication in Aggregate function {}", getName());
251
            else
252
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} doesn't support a parameter: {}", getName(), option);
253
        }
254
    }
255

256
    bool allocatesMemoryInArena() const override { return false; }
257

258
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
259
    {
260
        bool has_event = false;
261
        const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
262
        /// reverse iteration and stable sorting are needed for events that are qualified by more than one condition.
263
        for (auto i = events_size; i > 0; --i)
264
        {
265
            auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
266
            if (event)
267
            {
268
                this->data(place).add(timestamp, i);
269
                has_event = true;
270
            }
271
        }
272

273
        if (strict_order && !has_event)
274
            this->data(place).add(timestamp, 0);
275
    }
276

277
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
278
    {
279
        this->data(place).merge(this->data(rhs));
280
    }
281

282
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
283
    {
284
        this->data(place).serialize(buf);
285
    }
286

287
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version  */, Arena *) const override
288
    {
289
        this->data(place).deserialize(buf);
290
    }
291

292
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
293
    {
294
        assert_cast<ColumnUInt8 &>(to).getData().push_back(getEventLevel(this->data(place)));
295
    }
296
};
297

298

299
template <template <typename> class Data>
300
AggregateFunctionPtr
301
createAggregateFunctionWindowFunnel(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
302
{
303
    if (params.empty())
304
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
305
                        "Aggregate function {} requires at least one parameter: <window>, [option, [option, ...]]",
306
                        name);
307

308
    if (arguments.size() < 2)
309
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
310
                        "Aggregate function {} requires one timestamp argument and at least one event condition.", name);
311

312
    if (arguments.size() > max_events + 1)
313
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Too many event arguments for aggregate function {}", name);
314

315
    for (const auto i : collections::range(1, arguments.size()))
316
    {
317
        const auto * cond_arg = arguments[i].get();
318
        if (!isUInt8(cond_arg))
319
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
320
                            "Illegal type {} of argument {} of aggregate function {}, must be UInt8",
321
                            cond_arg->getName(), toString(i + 1), name);
322
    }
323

324
    AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionWindowFunnel, Data>(*arguments[0], arguments, params));
325
    WhichDataType which(arguments.front().get());
326
    if (res)
327
        return res;
328
    else if (which.isDate())
329
        return std::make_shared<AggregateFunctionWindowFunnel<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(arguments, params);
330
    else if (which.isDateTime())
331
        return std::make_shared<AggregateFunctionWindowFunnel<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(arguments, params);
332

333
    throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
334
                    "Illegal type {} of first argument of aggregate function {}, must "
335
                    "be Unsigned Number, Date, DateTime", arguments.front().get()->getName(), name);
336
}
337

338
}
339

340
/// Makes the aggregate function available in the factory under the name "windowFunnel".
void registerAggregateFunctionWindowFunnel(AggregateFunctionFactory & factory)
{
    const auto creator = createAggregateFunctionWindowFunnel<AggregateFunctionWindowFunnelData>;
    factory.registerFunction("windowFunnel", creator);
}
344

345
}
346

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.