// ClickHouse
// 127 lines · 4.6 KB
1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/FactoryHelpers.h>
3
4#include <IO/VarInt.h>
5
6#include <array>
7#include <DataTypes/DataTypesNumber.h>
8#include <DataTypes/DataTypeTuple.h>
9#include <Columns/ColumnNullable.h>
10#include <AggregateFunctions/IAggregateFunction.h>
11#include <AggregateFunctions/Moments.h>
12#include <Common/NaNUtils.h>
13#include <Common/assert_cast.h>
14
15
16namespace DB
17{
18
19namespace ErrorCodes
20{
21extern const int BAD_ARGUMENTS;
22}
23
24namespace
25{
26
27using AggregateFunctionAnalysisOfVarianceData = AnalysisOfVarianceMoments<Float64>;
28
29
30/// One way analysis of variance
31/// Provides a statistical test of whether two or more population means are equal (null hypothesis)
32/// Has an assumption that subjects from group i have normal distribution.
33/// Accepts two arguments - a value and a group number which this value belongs to.
34/// Groups are enumerated starting from 0 and there should be at least two groups to perform a test
35/// Moreover there should be at least one group with the number of observations greater than one.
36class AggregateFunctionAnalysisOfVariance final : public IAggregateFunctionDataHelper<AggregateFunctionAnalysisOfVarianceData, AggregateFunctionAnalysisOfVariance>
37{
38public:
39explicit AggregateFunctionAnalysisOfVariance(const DataTypes & arguments, const Array & params)
40: IAggregateFunctionDataHelper(arguments, params, createResultType())
41{}
42
43DataTypePtr createResultType() const
44{
45DataTypes types {std::make_shared<DataTypeNumber<Float64>>(), std::make_shared<DataTypeNumber<Float64>>() };
46Strings names {"f_statistic", "p_value"};
47return std::make_shared<DataTypeTuple>(
48std::move(types),
49std::move(names)
50);
51}
52
53String getName() const override { return "analysisOfVariance"; }
54
55bool allocatesMemoryInArena() const override { return false; }
56
57void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
58{
59data(place).add(columns[0]->getFloat64(row_num), columns[1]->getUInt(row_num));
60}
61
62void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
63{
64data(place).merge(data(rhs));
65}
66
67void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
68{
69data(place).write(buf);
70}
71
72void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
73{
74data(place).read(buf);
75}
76
77void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
78{
79auto f_stat = data(place).getFStatistic();
80
81auto & column_tuple = assert_cast<ColumnTuple &>(to);
82auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
83auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
84
85if (unlikely(!std::isfinite(f_stat) || f_stat < 0))
86{
87column_stat.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
88column_value.getData().push_back(std::numeric_limits<Float64>::quiet_NaN());
89return;
90}
91
92auto p_value = data(place).getPValue(f_stat);
93
94/// Because p-value is a probability.
95p_value = std::min(1.0, std::max(0.0, p_value));
96
97column_stat.getData().push_back(f_stat);
98column_value.getData().push_back(p_value);
99}
100
101};
102
103AggregateFunctionPtr createAggregateFunctionAnalysisOfVariance(const std::string & name, const DataTypes & arguments, const Array & parameters, const Settings *)
104{
105assertNoParameters(name, parameters);
106assertBinary(name, arguments);
107
108if (!isNumber(arguments[0]))
109throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical argument types", name);
110if (!WhichDataType(arguments[1]).isNativeUInt())
111throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of aggregate function {} should be a native unsigned integer", name);
112
113return std::make_shared<AggregateFunctionAnalysisOfVariance>(arguments, parameters);
114}
115
116}
117
118void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory & factory)
119{
120AggregateFunctionProperties properties = { .is_order_dependent = false };
121factory.registerFunction("analysisOfVariance", {createAggregateFunctionAnalysisOfVariance, properties}, AggregateFunctionFactory::CaseInsensitive);
122
123/// This is widely used term
124factory.registerAlias("anova", "analysisOfVariance", AggregateFunctionFactory::CaseInsensitive);
125}
126
127}
128