ClickHouse

Форк
0
225 строк · 8.8 Кб
1
#include <Columns/ColumnNullable.h>
2
#include <Columns/ColumnString.h>
3
#include <Columns/ColumnTuple.h>
4
#include <Columns/ColumnsNumber.h>
5
#include <Columns/IColumn.h>
6
#include <DataTypes/DataTypeTuple.h>
7
#include <DataTypes/DataTypesNumber.h>
8
#include <Functions/FunctionFactory.h>
9
#include <Functions/FunctionHelpers.h>
10
#include <Functions/IFunction.h>
11
#include <Functions/castTypeToEither.h>
12
#include <Interpreters/castColumn.h>
13
#include <boost/math/distributions/normal.hpp>
14
#include <Common/typeid_cast.h>
15

16

17
namespace DB
18
{
19

20
/// Error codes raised by this translation unit; only declared here, the definitions live elsewhere.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
25

26

27
class FunctionTwoSampleProportionsZTest : public IFunction
28
{
29
public:
30
    static constexpr auto POOLED = "pooled";
31
    static constexpr auto UNPOOLED = "unpooled";
32

33
    static constexpr auto name = "proportionsZTest";
34

35
    static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionTwoSampleProportionsZTest>(); }
36

37
    String getName() const override { return name; }
38

39
    size_t getNumberOfArguments() const override { return 6; }
40
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {5}; }
41

42
    bool useDefaultImplementationForNulls() const override { return false; }
43
    bool useDefaultImplementationForConstants() const override { return true; }
44
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
45

46
    static DataTypePtr getReturnType()
47
    {
48
        auto float_data_type = std::make_shared<DataTypeNumber<Float64>>();
49
        DataTypes types(4, float_data_type);
50

51
        Strings names{"z_statistic", "p_value", "confidence_interval_low", "confidence_interval_high"};
52

53
        return std::make_shared<DataTypeTuple>(std::move(types), std::move(names));
54
    }
55

56
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
57
    {
58
        for (size_t i = 0; i < 4; ++i)
59
        {
60
            if (!isUInt(arguments[i].type))
61
            {
62
                throw Exception(
63
                    ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
64
                    "The {}th Argument of function {} must be an unsigned integer.",
65
                    i + 1,
66
                    getName());
67
            }
68
        }
69

70
        if (!isFloat(arguments[4].type))
71
        {
72
            throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
73
                "The fifth argument {} of function {} should be a float,",
74
                arguments[4].type->getName(),
75
                getName()};
76
        }
77

78
        /// There is an additional check for constancy in ExecuteImpl
79
        if (!isString(arguments[5].type) || !arguments[5].column)
80
        {
81
            throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
82
                "The sixth argument {} of function {} should be a constant string",
83
                arguments[5].type->getName(),
84
                getName()};
85
        }
86

87
        return getReturnType();
88
    }
89

90

91
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & const_arguments, const DataTypePtr &, size_t input_rows_count) const override
92
    {
93
        auto arguments = const_arguments;
94
        /// Only last argument have to be constant
95
        for (size_t i = 0; i < 5; ++i)
96
            arguments[i].column = arguments[i].column->convertToFullColumnIfConst();
97

98
        static const auto uint64_data_type = std::make_shared<DataTypeNumber<UInt64>>();
99

100
        auto column_successes_x = castColumnAccurate(arguments[0], uint64_data_type);
101
        const auto & data_successes_x = checkAndGetColumn<ColumnVector<UInt64>>(column_successes_x.get())->getData();
102

103
        auto column_successes_y = castColumnAccurate(arguments[1], uint64_data_type);
104
        const auto & data_successes_y = checkAndGetColumn<ColumnVector<UInt64>>(column_successes_y.get())->getData();
105

106
        auto column_trials_x = castColumnAccurate(arguments[2], uint64_data_type);
107
        const auto & data_trials_x = checkAndGetColumn<ColumnVector<UInt64>>(column_trials_x.get())->getData();
108

109
        auto column_trials_y = castColumnAccurate(arguments[3], uint64_data_type);
110
        const auto & data_trials_y = checkAndGetColumn<ColumnVector<UInt64>>(column_trials_y.get())->getData();
111

112
        static const auto float64_data_type = std::make_shared<DataTypeNumber<Float64>>();
113

114
        auto column_confidence_level = castColumnAccurate(arguments[4], float64_data_type);
115
        const auto & data_confidence_level = checkAndGetColumn<ColumnVector<Float64>>(column_confidence_level.get())->getData();
116

117
        String usevar = checkAndGetColumnConst<ColumnString>(arguments[5].column.get())->getValue<String>();
118

119
        if (usevar != UNPOOLED && usevar != POOLED)
120
            throw Exception{ErrorCodes::BAD_ARGUMENTS,
121
                "The sixth argument {} of function {} must be equal to `pooled` or `unpooled`",
122
                arguments[5].type->getName(),
123
                getName()};
124

125
        const bool is_unpooled = (usevar == UNPOOLED);
126

127
        auto res_z_statistic = ColumnFloat64::create();
128
        auto & data_z_statistic = res_z_statistic->getData();
129
        data_z_statistic.reserve(input_rows_count);
130

131
        auto res_p_value = ColumnFloat64::create();
132
        auto & data_p_value = res_p_value->getData();
133
        data_p_value.reserve(input_rows_count);
134

135
        auto res_ci_lower = ColumnFloat64::create();
136
        auto & data_ci_lower = res_ci_lower->getData();
137
        data_ci_lower.reserve(input_rows_count);
138

139
        auto res_ci_upper = ColumnFloat64::create();
140
        auto & data_ci_upper = res_ci_upper->getData();
141
        data_ci_upper.reserve(input_rows_count);
142

143
        auto insert_values_into_result = [&data_z_statistic, &data_p_value, &data_ci_lower, &data_ci_upper](
144
                                             Float64 z_stat, Float64 p_value, Float64 lower, Float64 upper)
145
        {
146
            data_z_statistic.emplace_back(z_stat);
147
            data_p_value.emplace_back(p_value);
148
            data_ci_lower.emplace_back(lower);
149
            data_ci_upper.emplace_back(upper);
150
        };
151

152
        static constexpr Float64 nan = std::numeric_limits<Float64>::quiet_NaN();
153

154
        boost::math::normal_distribution<> nd(0.0, 1.0);
155

156
        for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
157
        {
158
            const UInt64 successes_x = data_successes_x[row_num];
159
            const UInt64 successes_y = data_successes_y[row_num];
160
            const UInt64 trials_x = data_trials_x[row_num];
161
            const UInt64 trials_y = data_trials_y[row_num];
162
            const Float64 confidence_level = data_confidence_level[row_num];
163

164
            const Float64 props_x = static_cast<Float64>(successes_x) / trials_x;
165
            const Float64 props_y = static_cast<Float64>(successes_y) / trials_y;
166
            const Float64 diff = props_x - props_y;
167
            const UInt64 trials_total = trials_x + trials_y;
168

169
            if (successes_x == 0 || successes_y == 0 || successes_x > trials_x || successes_y > trials_y || trials_total == 0
170
                || !std::isfinite(confidence_level) || confidence_level < 0.0 || confidence_level > 1.0)
171
            {
172
                insert_values_into_result(nan, nan, nan, nan);
173
                continue;
174
            }
175

176
            Float64 se = std::sqrt(props_x * (1.0 - props_x) / trials_x + props_y * (1.0 - props_y) / trials_y);
177

178
            /// z-statistics
179
            /// z = \frac{ \bar{p_{1}} - \bar{p_{2}} }{ \sqrt{ \frac{ \bar{p_{1}} \left ( 1 - \bar{p_{1}} \right ) }{ n_{1} } \frac{ \bar{p_{2}} \left ( 1 - \bar{p_{2}} \right ) }{ n_{2} } } }
180
            Float64 zstat;
181
            if (is_unpooled)
182
            {
183
                zstat = (props_x - props_y) / se;
184
            }
185
            else
186
            {
187
                UInt64 successes_total = successes_x + successes_y;
188
                Float64 p_pooled = static_cast<Float64>(successes_total) / trials_total;
189
                Float64 trials_fact = 1.0 / trials_x + 1.0 / trials_y;
190
                zstat = diff / std::sqrt(p_pooled * (1.0 - p_pooled) * trials_fact);
191
            }
192

193
            if (unlikely(!std::isfinite(zstat)))
194
            {
195
                insert_values_into_result(nan, nan, nan, nan);
196
                continue;
197
            }
198

199
            // pvalue
200
            Float64 pvalue = 0;
201
            Float64 one_side = 1 - boost::math::cdf(nd, std::abs(zstat));
202
            pvalue = one_side * 2;
203

204
            // Confidence intervals
205
            Float64 d = props_x - props_y;
206
            Float64 z = -boost::math::quantile(nd, (1.0 - confidence_level) / 2.0);
207
            Float64 dist = z * se;
208
            Float64 ci_low = d - dist;
209
            Float64 ci_high = d + dist;
210

211
            insert_values_into_result(zstat, pvalue, ci_low, ci_high);
212
        }
213

214
        return ColumnTuple::create(
215
            Columns{std::move(res_z_statistic), std::move(res_p_value), std::move(res_ci_lower), std::move(res_ci_upper)});
216
    }
217
};
218

219

220
/// Registers FunctionTwoSampleProportionsZTest with the function factory.
REGISTER_FUNCTION(ZTest)
{
    using Function = FunctionTwoSampleProportionsZTest;
    factory.registerFunction<Function>();
}
224

225
}
226

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.