ClickHouse

Форк
0
/
AggregateFunctionAnalysisOfVariance.cpp 
127 строк · 4.6 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/FactoryHelpers.h>
3

4
#include <IO/VarInt.h>
5

6
#include <array>
7
#include <DataTypes/DataTypesNumber.h>
8
#include <DataTypes/DataTypeTuple.h>
9
#include <Columns/ColumnNullable.h>
10
#include <AggregateFunctions/IAggregateFunction.h>
11
#include <AggregateFunctions/Moments.h>
12
#include <Common/NaNUtils.h>
13
#include <Common/assert_cast.h>
14

15

16
namespace DB
17
{
18

19
/// Error codes referenced in this file. Only declared here (extern); the definition lives elsewhere.
namespace ErrorCodes
20
{
21
    /// Thrown by the factory function below when the argument types are not supported.
    extern const int BAD_ARGUMENTS;
22
}
23

24
namespace
25
{
26

27
/// Aggregation state of the function: AnalysisOfVarianceMoments instantiated for Float64.
/// The aggregate function below relies on its add / merge / write / read / getFStatistic / getPValue methods.
using AggregateFunctionAnalysisOfVarianceData = AnalysisOfVarianceMoments<Float64>;
28

29

30
/// One way analysis of variance
31
/// Provides a statistical test of whether two or more population means are equal (null hypothesis)
32
/// Has an assumption that subjects from group i have normal distribution.
33
/// Accepts two arguments - a value and a group number which this value belongs to.
34
/// Groups are enumerated starting from 0 and there should be at least two groups to perform a test
35
/// Moreover there should be at least one group with the number of observations greater than one.
36
class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataHelper<AggregateFunctionAnalysisOfVarianceData, AggregateFunctionAnalysisOfVariance>
37
{
38
public:
39
    explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
40
        : IAggregateFunctionDataHelper(arguments, params, createResultType())
41
    {}
42

43
    DataTypePtr createResultType() const
44
    {
45
        DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
46
        Strings names {"f_statistic", "p_value"};
47
        return std::make_shared<DataTypeTuple>(
48
            std::move(types),
49
            std::move(names)
50
        );
51
    }
52

53
    String getName() const override { return "analysisOfVariance"; }
54

55
    bool allocatesMemoryInArena() const override { return false; }
56

57
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
58
    {
59
        data(place).add(columns[0]->getFloat64(row_num), columns[1]->getUInt(row_num));
60
    }
61

62
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
63
    {
64
        data(place).merge(data(rhs));
65
    }
66

67
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
68
    {
69
        data(place).write(buf);
70
    }
71

72
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
73
    {
74
        data(place).read(buf);
75
    }
76

77
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
78
    {
79
        auto f_stat = data(place).getFStatistic();
80

81
        auto & column_tuple = assert_cast<ColumnTuple &>(to);
82
        auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
83
        auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
84

85
        if (unlikely(!std::isfinite(f_stat) || f_stat < 0))
86
        {
87
            column_stat.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
88
            column_value.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
89
            return;
90
        }
91

92
        auto p_value = data(place).getPValue(f_stat);
93

94
        /// Because p-value is a probability.
95
        p_value = std::min(1.0, std::max(0.0, p_value));
96

97
        column_stat.getData().push_back(f_stat);
98
        column_value.getData().push_back(p_value);
99
    }
100

101
};
102

103
/// Factory callback: validates the call and constructs the aggregate function.
/// Parameter and arity checks are delegated to assertNoParameters / assertBinary.
/// Throws BAD_ARGUMENTS if the first argument is not a number or if the second
/// argument is not a native unsigned integer.
AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string & name, const DataTypes & arguments, const Array & parameters, const Settings *)
{
    assertNoParameters(name, parameters);
    assertBinary(name, arguments);

    /// First argument: the observed value.
    const bool value_is_numeric = isNumber(arguments[0]);
    if (!value_is_numeric)
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical argument types", name);
    }

    /// Second argument: the group number.
    const bool group_is_native_uint = WhichDataType(arguments[1]).isNativeUInt();
    if (!group_is_native_uint)
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of aggregate function {} should be a native unsigned integer", name);
    }

    return std::make_shared<AggregateFunctionAnalysisOfVariance>(arguments, parameters);
}
115

116
}
117

118
/// Registers the function in the factory under the case-insensitive name "analysisOfVariance"
/// and adds the case-insensitive alias "anova".
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory & factory)
{
    /// The function is registered as not order-dependent.
    AggregateFunctionProperties function_properties = { .is_order_dependent = false };

    factory.registerFunction(
        "analysisOfVariance",
        {createAggregateFunctionAnalysisOfVariance, function_properties},
        AggregateFunctionFactory::CaseInsensitive);

    /// "anova" is the widely used short name for this test.
    factory.registerAlias(
        "anova",
        "analysisOfVariance",
        AggregateFunctionFactory::CaseInsensitive);
}
126

127
}
128

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.