ClickHouse

Форк
0
/
initializeAggregation.cpp 
169 строк · 6.1 Кб
1
#include <Functions/IFunction.h>
2
#include <Functions/FunctionFactory.h>
3
#include <Functions/FunctionHelpers.h>
4
#include <Columns/ColumnString.h>
5
#include <Columns/ColumnAggregateFunction.h>
6
#include <AggregateFunctions/AggregateFunctionFactory.h>
7
#include <AggregateFunctions/Combinators/AggregateFunctionState.h>
8
#include <AggregateFunctions/IAggregateFunction.h>
9
#include <AggregateFunctions/parseAggregateFunctionParameters.h>
10
#include <Common/Arena.h>
11

12
#include <Common/scope_guard_safe.h>
13

14

15
namespace DB
16
{
17
namespace ErrorCodes
18
{
19
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
20
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
21
    extern const int BAD_ARGUMENTS;
22
}
23

24
namespace
25
{
26

27
class FunctionInitializeAggregation : public IFunction, private WithContext
28
{
29
public:
30
    static constexpr auto name = "initializeAggregation";
31
    static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionInitializeAggregation>(context_); }
32
    explicit FunctionInitializeAggregation(ContextPtr context_) : WithContext(context_) {}
33

34
    String getName() const override { return name; }
35

36
    bool isVariadic() const override { return true; }
37
    size_t getNumberOfArguments() const override { return 0; }
38

39
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
40

41
    bool useDefaultImplementationForConstants() const override { return true; }
42
    bool useDefaultImplementationForNulls() const override { return false; }
43
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; }
44

45
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
46

47
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
48

49
private:
50
    /// TODO Rewrite with FunctionBuilder.
51
    mutable AggregateFunctionPtr aggregate_function;
52
};
53

54

55
DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
56
{
57
    if (arguments.size() < 2)
58
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
59
            "Number of arguments for function {} doesn't match: passed {}, should be at least 2.",
60
            getName(), arguments.size());
61

62
    const ColumnConst * aggregate_function_name_column = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
63
    if (!aggregate_function_name_column)
64
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be constant string: "
65
            "name of aggregate function.", getName());
66

67
    DataTypes argument_types(arguments.size() - 1);
68
    for (size_t i = 1, size = arguments.size(); i < size; ++i)
69
    {
70
        argument_types[i - 1] = arguments[i].type;
71
    }
72

73
    if (!aggregate_function)
74
    {
75
        String aggregate_function_name_with_params = aggregate_function_name_column->getValue<String>();
76

77
        if (aggregate_function_name_with_params.empty())
78
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "First argument for function {} (name of aggregate function) cannot be empty.", getName());
79

80
        String aggregate_function_name;
81
        Array params_row;
82
        getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
83
                                                   aggregate_function_name, params_row, "function " + getName(), getContext());
84

85
        auto action = NullsAction::EMPTY; /// It is already embedded in the function name itself
86
        AggregateFunctionProperties properties;
87
        aggregate_function
88
            = AggregateFunctionFactory::instance().get(aggregate_function_name, action, argument_types, params_row, properties);
89
    }
90

91
    return aggregate_function->getResultType();
92
}
93

94

95
ColumnPtr FunctionInitializeAggregation::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
96
{
97
    const IAggregateFunction & agg_func = *aggregate_function;
98
    std::unique_ptr<Arena> arena = std::make_unique<Arena>();
99

100
    const size_t num_arguments_columns = arguments.size() - 1;
101

102
    std::vector<ColumnPtr> materialized_columns(num_arguments_columns);
103
    std::vector<const IColumn *> aggregate_arguments_vec(num_arguments_columns);
104

105
    for (size_t i = 0; i < num_arguments_columns; ++i)
106
    {
107
        const IColumn * col = arguments[i + 1].column.get();
108
        materialized_columns.emplace_back(col->convertToFullColumnIfConst());
109
        aggregate_arguments_vec[i] = &(*materialized_columns.back());
110
    }
111

112
    const IColumn ** aggregate_arguments = aggregate_arguments_vec.data();
113

114
    MutableColumnPtr result_holder = result_type->createColumn();
115
    IColumn & res_col = *result_holder;
116

117
    PODArray<AggregateDataPtr> places(input_rows_count);
118
    for (size_t i = 0; i < input_rows_count; ++i)
119
    {
120
        places[i] = arena->alignedAlloc(agg_func.sizeOfData(), agg_func.alignOfData());
121
        try
122
        {
123
            agg_func.create(places[i]);
124
        }
125
        catch (...)
126
        {
127
            for (size_t j = 0; j < i; ++j)
128
                agg_func.destroy(places[j]);
129
            throw;
130
        }
131
    }
132

133
    SCOPE_EXIT_MEMORY_SAFE({
134
        for (size_t i = 0; i < input_rows_count; ++i)
135
            agg_func.destroy(places[i]);
136
    });
137

138
    {
139
        const auto * that = &agg_func;
140
        /// Unnest consecutive trailing -State combinators
141
        while (const auto * func = typeid_cast<const AggregateFunctionState *>(that))
142
            that = func->getNestedFunction().get();
143
        that->addBatch(0, input_rows_count, places.data(), 0, aggregate_arguments, arena.get());
144
    }
145

146
    if (agg_func.isState())
147
    {
148
        /// We should use insertMergeResultInto to insert result into ColumnAggregateFunction
149
        /// correctly if result contains AggregateFunction's states
150
        for (size_t i = 0; i < input_rows_count; ++i)
151
            agg_func.insertMergeResultInto(places[i], res_col, arena.get());
152
    }
153
    else
154
    {
155
        for (size_t i = 0; i < input_rows_count; ++i)
156
            agg_func.insertResultInto(places[i], res_col, arena.get());
157
    }
158

159
    return result_holder;
160
}
161

162
}
163

164
REGISTER_FUNCTION(InitializeAggregation)
165
{
166
    factory.registerFunction<FunctionInitializeAggregation>();
167
}
168

169
}
170

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.