// ClickHouse
// 149 lines · 4.9 KB
1#include <AggregateFunctions/IAggregateFunction.h>2#include <Columns/ColumnAggregateFunction.h>3#include <DataTypes/DataTypeAggregateFunction.h>4#include <Functions/FunctionFactory.h>5#include <Functions/FunctionHelpers.h>6#include <Functions/IFunction.h>7#include <Common/AlignedBuffer.h>8#include <Common/Arena.h>9#include <Common/scope_guard_safe.h>10
11
12namespace DB13{
namespace ErrorCodes
{
    /// Error codes thrown by runningAccumulate; declared here, defined elsewhere in the project.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int ILLEGAL_COLUMN;
}
20
21namespace
22{
23
24/** runningAccumulate(agg_state) - takes the states of the aggregate function and returns a column with values,
25* are the result of the accumulation of these states for a set of columns lines, from the first to the current line.
26*
27* Quite unusual function.
28* Takes state of aggregate function (example runningAccumulate(uniqState(UserID))),
29* and for each row of columns, return result of aggregate function on merge of states of all previous rows and current row.
30*
31* So, result of function depends on partition of data to columns and on order of data in columns.
32*/
33class FunctionRunningAccumulate : public IFunction34{
35public:36static constexpr auto name = "runningAccumulate";37static FunctionPtr create(ContextPtr)38{39return std::make_shared<FunctionRunningAccumulate>();40}41
42String getName() const override43{44return name;45}46
47bool isStateful() const override48{49return true;50}51
52bool isVariadic() const override { return true; }53
54size_t getNumberOfArguments() const override { return 0; }55
56bool isDeterministic() const override57{58return false;59}60
61bool isDeterministicInScopeOfQuery() const override62{63return false;64}65
66bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }67
68DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override69{70if (arguments.empty() || arguments.size() > 2)71throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,72"Incorrect number of arguments of function {}. Must be 1 or 2.", getName());73
74const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());75if (!type)76throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,77"Argument for function {} must have type AggregateFunction - state "78"of aggregate function.", getName());79
80return type->getReturnType();81}82
83ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override84{85const ColumnAggregateFunction * column_with_states86= typeid_cast<const ColumnAggregateFunction *>(&*arguments.at(0).column);87
88if (!column_with_states)89throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}",90arguments.at(0).column->getName(), getName());91
92ColumnPtr column_with_groups;93
94if (arguments.size() == 2)95column_with_groups = arguments[1].column;96
97AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();98const IAggregateFunction & agg_func = *aggregate_function_ptr;99
100AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData());101
102/// Will pass empty arena if agg_func does not allocate memory in arena103std::unique_ptr<Arena> arena = agg_func.allocatesMemoryInArena() ? std::make_unique<Arena>() : nullptr;104
105auto result_column_ptr = agg_func.getResultType()->createColumn();106IColumn & result_column = *result_column_ptr;107result_column.reserve(column_with_states->size());108
109const auto & states = column_with_states->getData();110
111bool state_created = false;112SCOPE_EXIT_MEMORY_SAFE({113if (state_created)114agg_func.destroy(place.data());115});116
117size_t row_number = 0;118for (const auto & state_to_add : states)119{120if (row_number == 0 || (column_with_groups && column_with_groups->compareAt(row_number, row_number - 1, *column_with_groups, 1) != 0))121{122if (state_created)123{124agg_func.destroy(place.data());125state_created = false;126}127
128agg_func.create(place.data()); /// This function can throw.129state_created = true;130}131
132agg_func.merge(place.data(), state_to_add, arena.get());133agg_func.insertResultInto(place.data(), result_column, arena.get());134
135++row_number;136}137
138return result_column_ptr;139}140};141
142}
143
/// Registers FunctionRunningAccumulate with the function factory.
REGISTER_FUNCTION(RunningAccumulate)
{
    factory.registerFunction<FunctionRunningAccumulate>();
}
148
149}
150