ClickHouse
169 строк · 6.1 Кб
1#include <Functions/IFunction.h>
2#include <Functions/FunctionFactory.h>
3#include <Functions/FunctionHelpers.h>
4#include <Columns/ColumnString.h>
5#include <Columns/ColumnAggregateFunction.h>
6#include <AggregateFunctions/AggregateFunctionFactory.h>
7#include <AggregateFunctions/Combinators/AggregateFunctionState.h>
8#include <AggregateFunctions/IAggregateFunction.h>
9#include <AggregateFunctions/parseAggregateFunctionParameters.h>
10#include <Common/Arena.h>
11
12#include <Common/scope_guard_safe.h>
13
14
15namespace DB
16{
/// Error codes thrown by this translation unit (defined elsewhere).
namespace ErrorCodes
{
    /// The aggregate function name passed as the first argument is an empty string.
    extern const int BAD_ARGUMENTS;
    /// The first argument is not a constant string column.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Fewer than two arguments were passed to the function.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
23
24namespace
25{
26
27class FunctionInitializeAggregation : public IFunction, private WithContext
28{
29public:
30static constexpr auto name = "initializeAggregation";
31static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionInitializeAggregation>(context_); }
32explicit FunctionInitializeAggregation(ContextPtr context_) : WithContext(context_) {}
33
34String getName() const override { return name; }
35
36bool isVariadic() const override { return true; }
37size_t getNumberOfArguments() const override { return 0; }
38
39bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
40
41bool useDefaultImplementationForConstants() const override { return true; }
42bool useDefaultImplementationForNulls() const override { return false; }
43ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; }
44
45DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
46
47ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
48
49private:
50/// TODO Rewrite with FunctionBuilder.
51mutable AggregateFunctionPtr aggregate_function;
52};
53
54
55DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
56{
57if (arguments.size() < 2)
58throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
59"Number of arguments for function {} doesn't match: passed {}, should be at least 2.",
60getName(), arguments.size());
61
62const ColumnConst * aggregate_function_name_column = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
63if (!aggregate_function_name_column)
64throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be constant string: "
65"name of aggregate function.", getName());
66
67DataTypes argument_types(arguments.size() - 1);
68for (size_t i = 1, size = arguments.size(); i < size; ++i)
69{
70argument_types[i - 1] = arguments[i].type;
71}
72
73if (!aggregate_function)
74{
75String aggregate_function_name_with_params = aggregate_function_name_column->getValue<String>();
76
77if (aggregate_function_name_with_params.empty())
78throw Exception(ErrorCodes::BAD_ARGUMENTS, "First argument for function {} (name of aggregate function) cannot be empty.", getName());
79
80String aggregate_function_name;
81Array params_row;
82getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
83aggregate_function_name, params_row, "function " + getName(), getContext());
84
85auto action = NullsAction::EMPTY; /// It is already embedded in the function name itself
86AggregateFunctionProperties properties;
87aggregate_function
88= AggregateFunctionFactory::instance().get(aggregate_function_name, action, argument_types, params_row, properties);
89}
90
91return aggregate_function->getResultType();
92}
93
94
95ColumnPtr FunctionInitializeAggregation::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
96{
97const IAggregateFunction & agg_func = *aggregate_function;
98std::unique_ptr<Arena> arena = std::make_unique<Arena>();
99
100const size_t num_arguments_columns = arguments.size() - 1;
101
102std::vector<ColumnPtr> materialized_columns(num_arguments_columns);
103std::vector<const IColumn *> aggregate_arguments_vec(num_arguments_columns);
104
105for (size_t i = 0; i < num_arguments_columns; ++i)
106{
107const IColumn * col = arguments[i + 1].column.get();
108materialized_columns.emplace_back(col->convertToFullColumnIfConst());
109aggregate_arguments_vec[i] = &(*materialized_columns.back());
110}
111
112const IColumn ** aggregate_arguments = aggregate_arguments_vec.data();
113
114MutableColumnPtr result_holder = result_type->createColumn();
115IColumn & res_col = *result_holder;
116
117PODArray<AggregateDataPtr> places(input_rows_count);
118for (size_t i = 0; i < input_rows_count; ++i)
119{
120places[i] = arena->alignedAlloc(agg_func.sizeOfData(), agg_func.alignOfData());
121try
122{
123agg_func.create(places[i]);
124}
125catch (...)
126{
127for (size_t j = 0; j < i; ++j)
128agg_func.destroy(places[j]);
129throw;
130}
131}
132
133SCOPE_EXIT_MEMORY_SAFE({
134for (size_t i = 0; i < input_rows_count; ++i)
135agg_func.destroy(places[i]);
136});
137
138{
139const auto * that = &agg_func;
140/// Unnest consecutive trailing -State combinators
141while (const auto * func = typeid_cast<const AggregateFunctionState *>(that))
142that = func->getNestedFunction().get();
143that->addBatch(0, input_rows_count, places.data(), 0, aggregate_arguments, arena.get());
144}
145
146if (agg_func.isState())
147{
148/// We should use insertMergeResultInto to insert result into ColumnAggregateFunction
149/// correctly if result contains AggregateFunction's states
150for (size_t i = 0; i < input_rows_count; ++i)
151agg_func.insertMergeResultInto(places[i], res_col, arena.get());
152}
153else
154{
155for (size_t i = 0; i < input_rows_count; ++i)
156agg_func.insertResultInto(places[i], res_col, arena.get());
157}
158
159return result_holder;
160}
161
162}
163
164REGISTER_FUNCTION(InitializeAggregation)
165{
166factory.registerFunction<FunctionInitializeAggregation>();
167}
168
169}
170