ClickHouse

Форк
0
/
AggregateFunctionRetention.cpp 
168 строк · 4.9 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/IAggregateFunction.h>
3
#include <AggregateFunctions/FactoryHelpers.h>
4
#include <Columns/ColumnArray.h>
5
#include <Common/assert_cast.h>
6
#include <DataTypes/DataTypesNumber.h>
7
#include <DataTypes/DataTypeArray.h>
8
#include <IO/ReadHelpers.h>
9
#include <IO/WriteHelpers.h>
10
#include <base/range.h>
11

12
#include <bitset>
13
#include <unordered_set>
14

15

16
namespace DB
17
{
18

19
/// Forward declaration only: this file uses `Settings` solely as `const Settings *`
/// (in the factory signature), so the full definition is not required here.
struct Settings;
20

21
/// Error codes thrown from this file; declared `extern`, so they are defined elsewhere.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;          // a condition argument is not UInt8
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;  // too few / too many condition arguments
}
26

27
namespace
28
{
29

30
struct AggregateFunctionRetentionData
31
{
32
    static constexpr auto max_events = 32;
33

34
    using Events = std::bitset<max_events>;
35

36
    Events events;
37

38
    void add(UInt8 event)
39
    {
40
        events.set(event);
41
    }
42

43
    void merge(const AggregateFunctionRetentionData & other)
44
    {
45
        events |= other.events;
46
    }
47

48
    void serialize(WriteBuffer & buf) const
49
    {
50
        UInt32 event_value = static_cast<UInt32>(events.to_ulong());
51
        writeBinary(event_value, buf);
52
    }
53

54
    void deserialize(ReadBuffer & buf)
55
    {
56
        UInt32 event_value;
57
        readBinary(event_value, buf);
58
        events = event_value;
59
    }
60
};
61

62
/**
63
  * The max size of events is 32, that's enough for retention analytics
64
  *
65
  * Usage:
66
  * - retention(cond1, cond2, cond3, ....)
67
  * - returns [cond1_flag, cond1_flag && cond2_flag, cond1_flag && cond3_flag, ...]
68
  */
69
class AggregateFunctionRetention final
70
        : public IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>
71
{
72
private:
73
    UInt8 events_size;
74

75
public:
76
    String getName() const override
77
    {
78
        return "retention";
79
    }
80

81
    explicit AggregateFunctionRetention(const DataTypes & arguments)
82
        : IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}, std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()))
83
    {
84
        for (const auto i : collections::range(0, arguments.size()))
85
        {
86
            const auto * cond_arg = arguments[i].get();
87
            if (!isUInt8(cond_arg))
88
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
89
                                "Illegal type {} of argument {} of aggregate function {}, must be UInt8",
90
                                cond_arg->getName(), i, getName());
91
        }
92

93
        events_size = static_cast<UInt8>(arguments.size());
94
    }
95

96
    bool allocatesMemoryInArena() const override { return false; }
97

98
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
99
    {
100
        for (const auto i : collections::range(0, events_size))
101
        {
102
            auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
103
            if (event)
104
            {
105
                data(place).add(i);
106
            }
107
        }
108
    }
109

110
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
111
    {
112
        data(place).merge(data(rhs));
113
    }
114

115
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
116
    {
117
        data(place).serialize(buf);
118
    }
119

120
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
121
    {
122
        data(place).deserialize(buf);
123
    }
124

125
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
126
    {
127
        auto & data_to = assert_cast<ColumnUInt8 &>(assert_cast<ColumnArray &>(to).getData()).getData();
128
        auto & offsets_to = assert_cast<ColumnArray &>(to).getOffsets();
129

130
        ColumnArray::Offset current_offset = data_to.size();
131
        data_to.resize(current_offset + events_size);
132

133
        const bool first_flag = data(place).events.test(0);
134
        data_to[current_offset] = first_flag;
135
        ++current_offset;
136

137
        for (size_t i = 1; i < events_size; ++i)
138
        {
139
            data_to[current_offset] = (first_flag && data(place).events.test(i));
140
            ++current_offset;
141
        }
142

143
        offsets_to.push_back(current_offset);
144
    }
145
};
146

147

148
/// Factory for `retention(cond1, ..., condN)`.
/// Rejects any parameters and requires between 2 and
/// AggregateFunctionRetentionData::max_events condition arguments.
AggregateFunctionPtr createAggregateFunctionRetention(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
{
    // retention() accepts no parametric arguments.
    assertNoParameters(name, params);

    const auto num_events = arguments.size();
    const bool too_few = num_events < 2;
    const bool too_many = num_events > AggregateFunctionRetentionData::max_events;

    if (too_few)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Not enough event arguments for aggregate function {}", name);
    if (too_many)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Too many event arguments for aggregate function {}", name);

    return std::make_shared<AggregateFunctionRetention>(arguments);
}
160

161
}
162

163
/// Makes `retention` available through the aggregate function factory,
/// backed by createAggregateFunctionRetention.
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory)
{
    factory.registerFunction(/* name = */ "retention", /* creator = */ &createAggregateFunctionRetention);
}
167

168
}
169

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.