ClickHouse

Форк
0
/
AggregateFunctionMeanZTest.cpp 
186 строк · 6.2 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/FactoryHelpers.h>
3
#include <AggregateFunctions/Moments.h>
4

5
#include <AggregateFunctions/IAggregateFunction.h>
6
#include <AggregateFunctions/StatCommon.h>
7
#include <Columns/ColumnVector.h>
8
#include <Columns/ColumnTuple.h>
9
#include <Common/assert_cast.h>
10
#include <DataTypes/DataTypesNumber.h>
11
#include <DataTypes/DataTypeTuple.h>
12
#include <cmath>
13

14

15
/// Error codes used in this file. They must live in DB::ErrorCodes, which is where
/// `ErrorCodes::...` resolves from inside namespace DB; a global-namespace declaration is never used.
namespace DB
{
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
20

21

22
namespace DB
23
{
24
struct Settings;
25

26
namespace
27
{
28

29
/// Returns tuple of (z-statistic, p-value, confidence-interval-low, confidence-interval-high)
30
template <typename Data>
31
class AggregateFunctionMeanZTest :
32
    public IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>
33
{
34
private:
35
    Float64 pop_var_x;
36
    Float64 pop_var_y;
37
    Float64 confidence_level;
38

39
public:
40
    AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
41
        : IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params, createResultType())
42
    {
43
        pop_var_x = params.at(0).safeGet<Float64>();
44
        pop_var_y = params.at(1).safeGet<Float64>();
45
        confidence_level = params.at(2).safeGet<Float64>();
46

47
        if (!std::isfinite(pop_var_x) || !std::isfinite(pop_var_y) || !std::isfinite(confidence_level))
48
        {
49
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} requires finite parameter values.", Data::name);
50
        }
51

52
        if (pop_var_x < 0.0 || pop_var_y < 0.0)
53
        {
54
            throw Exception(ErrorCodes::BAD_ARGUMENTS,
55
                            "Population variance parameters must be larger than or equal to zero "
56
                            "in aggregate function {}.", Data::name);
57
        }
58

59
        if (confidence_level <= 0.0 || confidence_level >= 1.0)
60
        {
61
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Confidence level parameter must be between 0 and 1 in aggregate function {}.", Data::name);
62
        }
63
    }
64

65
    String getName() const override
66
    {
67
        return Data::name;
68
    }
69

70
    static DataTypePtr createResultType()
71
    {
72
        DataTypes types
73
        {
74
            std::make_shared<DataTypeNumber<Float64>>(),
75
            std::make_shared<DataTypeNumber<Float64>>(),
76
            std::make_shared<DataTypeNumber<Float64>>(),
77
            std::make_shared<DataTypeNumber<Float64>>(),
78
        };
79

80
        Strings names
81
        {
82
            "z_statistic",
83
            "p_value",
84
            "confidence_interval_low",
85
            "confidence_interval_high"
86
        };
87

88
        return std::make_shared<DataTypeTuple>(
89
            std::move(types),
90
            std::move(names)
91
        );
92
    }
93

94
    bool allocatesMemoryInArena() const override { return false; }
95

96
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
97
    {
98
        Float64 value = columns[0]->getFloat64(row_num);
99
        UInt8 is_second = columns[1]->getUInt(row_num);
100

101
        if (is_second)
102
            this->data(place).addY(value);
103
        else
104
            this->data(place).addX(value);
105
    }
106

107
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
108
    {
109
        this->data(place).merge(this->data(rhs));
110
    }
111

112
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
113
    {
114
        this->data(place).write(buf);
115
    }
116

117
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
118
    {
119
        this->data(place).read(buf);
120
    }
121

122
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
123
    {
124
        auto [z_stat, p_value] = this->data(place).getResult(pop_var_x, pop_var_y);
125
        auto [ci_low, ci_high] = this->data(place).getConfidenceIntervals(pop_var_x, pop_var_y, confidence_level);
126

127
        /// Because p-value is a probability.
128
        p_value = std::min(1.0, std::max(0.0, p_value));
129

130
        auto & column_tuple = assert_cast<ColumnTuple &>(to);
131
        auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
132
        auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
133
        auto & column_ci_low = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(2));
134
        auto & column_ci_high = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(3));
135

136
        column_stat.getData().push_back(z_stat);
137
        column_value.getData().push_back(p_value);
138
        column_ci_low.getData().push_back(ci_low);
139
        column_ci_high.getData().push_back(ci_high);
140
    }
141
};
142

143

144
/// Aggregation state for `meanZTest`.
/// Sample statistics (getMeanX / getMeanY / getStandardError) come from ZTestMoments<Float64>.
struct MeanZTestData : public ZTestMoments<Float64>
{
    static constexpr auto name = "meanZTest";

    /// Returns (z-statistic, two-sided p-value) for the difference of the means of the
    /// first (X) and second (Y) samples, given their population variances.
    /// Both values are NaN when the z-statistic is not finite.
    std::pair<Float64, Float64> getResult(Float64 pop_var_x, Float64 pop_var_y) const
    {
        /// z = \frac{\bar{X_{1}} - \bar{X_{2}}}{\sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}}
        const Float64 mean_difference = getMeanX() - getMeanY();
        const Float64 z_value = mean_difference / getStandardError(pop_var_x, pop_var_y);

        if (unlikely(!std::isfinite(z_value)))
        {
            constexpr Float64 not_a_number = std::numeric_limits<Float64>::quiet_NaN();
            return {not_a_number, not_a_number};
        }

        /// Two-sided test: P(|Z| >= |z|) = 2 * Phi(-|z|) for a standard normal Z.
        const boost::math::normal standard_normal(0.0, 1.0);
        const Float64 two_sided_p = 2.0 * boost::math::cdf(standard_normal, -std::abs(z_value));

        return {z_value, two_sided_p};
    }
};
164

165
/// Factory for `meanZTest`.
/// Requires two numeric arguments (value, sample index; the count is checked by assertBinary)
/// and exactly three parameters (population variance of the first sample, population variance
/// of the second sample, confidence level). Parameter values are validated by the
/// AggregateFunctionMeanZTest constructor.
AggregateFunctionPtr createAggregateFunctionMeanZTest(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertBinary(name, argument_types);

    if (parameters.size() != 3)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                        "Aggregate function {} requires three parameters, got {}.", name, parameters.size());

    if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical types", name);

    return std::make_shared<AggregateFunctionMeanZTest<MeanZTestData>>(argument_types, parameters);
}
178

179
}
180

181
/// Registers `meanZTest` in the aggregate function factory.
/// The name is taken from MeanZTestData::name, the same constant getName() returns,
/// so the registered name and the reported name cannot drift apart.
void registerAggregateFunctionMeanZTest(AggregateFunctionFactory & factory)
{
    factory.registerFunction(MeanZTestData::name, createAggregateFunctionMeanZTest);
}
185

186
}
187

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.