ClickHouse
169 строк · 6.1 Кб
1#include <memory>
2#include <type_traits>
3#include <AggregateFunctions/AggregateFunctionAvg.h>
4#include <AggregateFunctions/AggregateFunctionFactory.h>
5#include <AggregateFunctions/Helpers.h>
6#include <AggregateFunctions/FactoryHelpers.h>
7
8
9namespace DB
10{
11
namespace ErrorCodes
{
    /// Thrown by the avgWeighted factory when the argument types are not supported.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}

/// Forward declaration: this file only needs Settings as an (unused) pointer parameter type.
struct Settings;
18
19namespace
20{
21
22template <typename T>
23using AvgWeightedFieldType = std::conditional_t<DecimalOrExtendedInt<T>,
24Float64, // no way to do UInt128 * UInt128, better cast to Float64
25NearestFieldType<T>>;
26
27template <typename T, typename U>
28using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
29AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;
30
31template <typename Value, typename Weight>
32class AggregateFunctionAvgWeighted final :
33public AggregateFunctionAvgBase<
34MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
35{
36public:
37using Base = AggregateFunctionAvgBase<
38MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
39using Base::Base;
40
41using Numerator = typename Base::Numerator;
42using Denominator = typename Base::Denominator;
43using Fraction = typename Base::Fraction;
44
45void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
46{
47const auto & weights = static_cast<const ColumnVector<Weight> &>(*columns[1]);
48
49this->data(place).numerator += static_cast<Numerator>(
50static_cast<const ColumnVector<Value> &>(*columns[0]).getData()[row_num])
51* static_cast<Numerator>(weights.getData()[row_num]);
52
53this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
54}
55
56String getName() const override { return "avgWeighted"; }
57
58#if USE_EMBEDDED_COMPILER
59
60bool isCompilable() const override
61{
62bool can_be_compiled = Base::isCompilable();
63can_be_compiled &= canBeNativeType<Weight>();
64
65return can_be_compiled;
66}
67
68void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
69{
70llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
71
72auto * numerator_type = toNativeType<Numerator>(b);
73auto * numerator_ptr = aggregate_data_ptr;
74auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
75
76auto numerator_data_type = toNativeDataType<Numerator>();
77auto * argument = nativeCast(b, arguments[0], numerator_data_type);
78auto * weight = nativeCast(b, arguments[1], numerator_data_type);
79
80llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
81auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
82b.CreateStore(numerator_result_value, numerator_ptr);
83
84auto * denominator_type = toNativeType<Denominator>(b);
85
86static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
87auto * denominator_ptr = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, denominator_offset);
88
89auto * weight_cast_to_denominator = nativeCast(b, arguments[1], toNativeDataType<Denominator>());
90
91auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
92auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
93
94b.CreateStore(denominator_value_updated, denominator_ptr);
95}
96
97#endif
98
99};
100
101bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
102{
103const WhichDataType l_dt(left), r_dt(right);
104
105constexpr auto allow = [](WhichDataType t)
106{
107return t.isInt() || t.isUInt() || t.isFloat();
108};
109
110return allow(l_dt) && allow(r_dt);
111}
112
/// Expands to a `switch` over `which.idx` (a local named `which` must be in scope) with one
/// `CASE(Type)` per numeric type supported by avgWeighted; any other type makes the enclosing
/// function return nullptr.
#define AT_SWITCH(CASE) \
    switch (which.idx) \
    { \
        CASE(Int8); CASE(Int16); CASE(Int32); CASE(Int64); \
        CASE(Int128); CASE(Int256); \
        CASE(UInt8); CASE(UInt16); CASE(UInt32); CASE(UInt64); \
        CASE(UInt128); CASE(UInt256); \
        CASE(Float32); CASE(Float64); \
        default: return nullptr; \
    }
121
122template <class First, class ... TArgs>
123IAggregateFunction * create(const IDataType & second_type, TArgs && ... args)
124{
125const WhichDataType which(second_type);
126
127#define LINE(Type) \
128case TypeIndex::Type: return new AggregateFunctionAvgWeighted<First, Type>(std::forward<TArgs>(args)...)
129AT_SWITCH(LINE)
130#undef LINE
131}
132
133// Not using helper functions because there are no templates for binary decimal/numeric function.
134template <class... TArgs>
135IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
136{
137const WhichDataType which(first_type);
138
139#define LINE(Type) \
140case TypeIndex::Type: return create<Type, TArgs...>(second_type, std::forward<TArgs>(args)...)
141AT_SWITCH(LINE)
142#undef LINE
143}
144
145AggregateFunctionPtr
146createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
147{
148assertNoParameters(name, parameters);
149assertBinary(name, argument_types);
150
151const auto data_type = static_cast<const DataTypePtr>(argument_types[0]);
152const auto data_type_weight = static_cast<const DataTypePtr>(argument_types[1]);
153
154if (!allowTypes(data_type, data_type_weight))
155throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
156"Types {} and {} are non-conforming as arguments for aggregate function {}",
157data_type->getName(), data_type_weight->getName(), name);
158
159return AggregateFunctionPtr(create(*data_type, *data_type_weight, argument_types));
160}
161
162}
163
164void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
165{
166factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted);
167}
168
169}
170