ClickHouse

Форк
0
/
AggregateFunctionAvgWeighted.cpp 
169 строк · 6.1 Кб
1
#include <memory>
2
#include <type_traits>
3
#include <AggregateFunctions/AggregateFunctionAvg.h>
4
#include <AggregateFunctions/AggregateFunctionFactory.h>
5
#include <AggregateFunctions/Helpers.h>
6
#include <AggregateFunctions/FactoryHelpers.h>
7

8

9
namespace DB
10
{
11

12
/// Forward declaration only: the creator below receives Settings solely through a pointer.
struct Settings;

/// Error code raised by this file; the `extern` declaration means it is defined elsewhere.
namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; }
18

19
namespace
20
{
21

22
template <typename T>
23
using AvgWeightedFieldType = std::conditional_t<DecimalOrExtendedInt<T>,
24
        Float64, // no way to do UInt128 * UInt128, better cast to Float64
25
        NearestFieldType<T>>;
26

27
template <typename T, typename U>
28
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
29
    AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;
30

31
template <typename Value, typename Weight>
32
class AggregateFunctionAvgWeighted final :
33
    public AggregateFunctionAvgBase<
34
        MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
35
{
36
public:
37
    using Base = AggregateFunctionAvgBase<
38
        MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
39
    using Base::Base;
40

41
    using Numerator = typename Base::Numerator;
42
    using Denominator = typename Base::Denominator;
43
    using Fraction = typename Base::Fraction;
44

45
    void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
46
    {
47
        const auto & weights = static_cast<const ColumnVector<Weight> &>(*columns[1]);
48

49
        this->data(place).numerator += static_cast<Numerator>(
50
            static_cast<const ColumnVector<Value> &>(*columns[0]).getData()[row_num])
51
            * static_cast<Numerator>(weights.getData()[row_num]);
52

53
        this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
54
    }
55

56
    String getName() const override { return "avgWeighted"; }
57

58
#if USE_EMBEDDED_COMPILER
59

60
    bool isCompilable() const override
61
    {
62
        bool can_be_compiled = Base::isCompilable();
63
        can_be_compiled &= canBeNativeType<Weight>();
64

65
        return can_be_compiled;
66
    }
67

68
    void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override
69
    {
70
        llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
71

72
        auto * numerator_type = toNativeType<Numerator>(b);
73
        auto * numerator_ptr = aggregate_data_ptr;
74
        auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
75

76
        auto numerator_data_type = toNativeDataType<Numerator>();
77
        auto * argument = nativeCast(b, arguments[0], numerator_data_type);
78
        auto * weight = nativeCast(b, arguments[1], numerator_data_type);
79

80
        llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
81
        auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
82
        b.CreateStore(numerator_result_value, numerator_ptr);
83

84
        auto * denominator_type = toNativeType<Denominator>(b);
85

86
        static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
87
        auto * denominator_ptr = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, denominator_offset);
88

89
        auto * weight_cast_to_denominator = nativeCast(b, arguments[1], toNativeDataType<Denominator>());
90

91
        auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
92
        auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
93

94
        b.CreateStore(denominator_value_updated, denominator_ptr);
95
    }
96

97
#endif
98

99
};
100

101
bool allowTypes(const DataTypePtr& left, const DataTypePtr& right) noexcept
102
{
103
    const WhichDataType l_dt(left), r_dt(right);
104

105
    constexpr auto allow = [](WhichDataType t)
106
    {
107
        return t.isInt() || t.isUInt() || t.isFloat();
108
    };
109

110
    return allow(l_dt) && allow(r_dt);
111
}
112

113
/// Expands to a `switch` over `which.idx`; a variable named `which` must be in scope at the
/// expansion site. It emits one EMIT_CASE(T) per supported numeric type T, and any other
/// type index makes the enclosing function return nullptr.
#define AT_SWITCH(EMIT_CASE) \
    switch (which.idx) \
    { \
        EMIT_CASE(Int8); \
        EMIT_CASE(Int16); \
        EMIT_CASE(Int32); \
        EMIT_CASE(Int64); \
        EMIT_CASE(Int128); \
        EMIT_CASE(Int256); \
        EMIT_CASE(UInt8); \
        EMIT_CASE(UInt16); \
        EMIT_CASE(UInt32); \
        EMIT_CASE(UInt64); \
        EMIT_CASE(UInt128); \
        EMIT_CASE(UInt256); \
        EMIT_CASE(Float32); \
        EMIT_CASE(Float64); \
        default: return nullptr; \
    }
121

122
template <class First, class ... TArgs>
123
IAggregateFunction * create(const IDataType & second_type, TArgs && ... args)
124
{
125
    const WhichDataType which(second_type);
126

127
#define LINE(Type) \
128
    case TypeIndex::Type:       return new AggregateFunctionAvgWeighted<First, Type>(std::forward<TArgs>(args)...)
129
    AT_SWITCH(LINE)
130
#undef LINE
131
}
132

133
// Not using helper functions because there are no templates for binary decimal/numeric function.
134
template <class... TArgs>
135
IAggregateFunction * create(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
136
{
137
    const WhichDataType which(first_type);
138

139
#define LINE(Type) \
140
    case TypeIndex::Type:       return create<Type, TArgs...>(second_type, std::forward<TArgs>(args)...)
141
    AT_SWITCH(LINE)
142
#undef LINE
143
}
144

145
/// Factory creator for avgWeighted(value, weight).
/// Requires no parameters and exactly two arguments whose types pass allowTypes.
/// Throws ILLEGAL_TYPE_OF_ARGUMENT when the argument types are not supported.
AggregateFunctionPtr
createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertNoParameters(name, parameters);
    assertBinary(name, argument_types);

    /// Bind by reference: copying the DataTypePtr handles would only churn their refcounts.
    const auto & data_type = argument_types[0];
    const auto & data_type_weight = argument_types[1];

    /// create() yields nullptr for a type index missing from AT_SWITCH. Treat that exactly like a
    /// rejected type (defensive, in case allowTypes and AT_SWITCH ever disagree) instead of
    /// returning a null function to the factory.
    IAggregateFunction * res = allowTypes(data_type, data_type_weight)
        ? create(*data_type, *data_type_weight, argument_types)
        : nullptr;

    if (!res)
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                        "Types {} and {} are non-conforming as arguments for aggregate function {}",
                        data_type->getName(), data_type_weight->getName(), name);

    return AggregateFunctionPtr(res);
}
161

162
}
163

164
/// Registers avgWeighted with the aggregate function factory, using
/// createAggregateFunctionAvgWeighted as its creator.
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
    static constexpr auto function_name = "avgWeighted";
    factory.registerFunction(function_name, createAggregateFunctionAvgWeighted);
}
168

169
}
170

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.