ClickHouse

Форк
0
/
runningAccumulate.cpp 
149 строк · 4.9 Кб
1
#include <AggregateFunctions/IAggregateFunction.h>
2
#include <Columns/ColumnAggregateFunction.h>
3
#include <DataTypes/DataTypeAggregateFunction.h>
4
#include <Functions/FunctionFactory.h>
5
#include <Functions/FunctionHelpers.h>
6
#include <Functions/IFunction.h>
7
#include <Common/AlignedBuffer.h>
8
#include <Common/Arena.h>
9
#include <Common/scope_guard_safe.h>
10

11

12
namespace DB
13
{
14
namespace ErrorCodes
15
{
16
    extern const int ILLEGAL_COLUMN;
17
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
18
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
19
}
20

21
namespace
22
{
23

24
/** runningAccumulate(agg_state) - takes the states of the aggregate function and returns a column with values,
25
  * are the result of the accumulation of these states for a set of columns lines, from the first to the current line.
26
  *
27
  * Quite unusual function.
28
  * Takes state of aggregate function (example runningAccumulate(uniqState(UserID))),
29
  *  and for each row of columns, return result of aggregate function on merge of states of all previous rows and current row.
30
  *
31
  * So, result of function depends on partition of data to columns and on order of data in columns.
32
  */
33
class FunctionRunningAccumulate : public IFunction
34
{
35
public:
36
    static constexpr auto name = "runningAccumulate";
37
    static FunctionPtr create(ContextPtr)
38
    {
39
        return std::make_shared<FunctionRunningAccumulate>();
40
    }
41

42
    String getName() const override
43
    {
44
        return name;
45
    }
46

47
    bool isStateful() const override
48
    {
49
        return true;
50
    }
51

52
    bool isVariadic() const override { return true; }
53

54
    size_t getNumberOfArguments() const override { return 0; }
55

56
    bool isDeterministic() const override
57
    {
58
        return false;
59
    }
60

61
    bool isDeterministicInScopeOfQuery() const override
62
    {
63
        return false;
64
    }
65

66
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
67

68
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
69
    {
70
        if (arguments.empty() || arguments.size() > 2)
71
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
72
                "Incorrect number of arguments of function {}. Must be 1 or 2.", getName());
73

74
        const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
75
        if (!type)
76
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
77
                            "Argument for function {} must have type AggregateFunction - state "
78
                            "of aggregate function.", getName());
79

80
        return type->getReturnType();
81
    }
82

83
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
84
    {
85
        const ColumnAggregateFunction * column_with_states
86
            = typeid_cast<const ColumnAggregateFunction *>(&*arguments.at(0).column);
87

88
        if (!column_with_states)
89
            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}",
90
                    arguments.at(0).column->getName(), getName());
91

92
        ColumnPtr column_with_groups;
93

94
        if (arguments.size() == 2)
95
            column_with_groups = arguments[1].column;
96

97
        AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
98
        const IAggregateFunction & agg_func = *aggregate_function_ptr;
99

100
        AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData());
101

102
        /// Will pass empty arena if agg_func does not allocate memory in arena
103
        std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr;
104

105
        auto result_column_ptr = agg_func.getResultType()->createColumn();
106
        IColumn & result_column = *result_column_ptr;
107
        result_column.reserve(column_with_states->size());
108

109
        const auto & states = column_with_states->getData();
110

111
        bool state_created = false;
112
        SCOPE_EXIT_MEMORY_SAFE({
113
            if (state_created)
114
                agg_func.destroy(place.data());
115
        });
116

117
        size_t row_number = 0;
118
        for (const auto & state_to_add : states)
119
        {
120
            if (row_number == 0 || (column_with_groups && column_with_groups->compareAt(row_number, row_number - 1, *column_with_groups, 1) != 0))
121
            {
122
                if (state_created)
123
                {
124
                    agg_func.destroy(place.data());
125
                    state_created = false;
126
                }
127

128
                agg_func.create(place.data()); /// This function can throw.
129
                state_created = true;
130
            }
131

132
            agg_func.merge(place.data(), state_to_add, arena.get());
133
            agg_func.insertResultInto(place.data(), result_column, arena.get());
134

135
            ++row_number;
136
        }
137

138
        return result_column_ptr;
139
    }
140
};
141

142
}
143

144
REGISTER_FUNCTION(RunningAccumulate)
145
{
146
    factory.registerFunction<FunctionRunningAccumulate>();
147
}
148

149
}
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.