ClickHouse
225 строк · 8.8 Кб
1#include <Columns/ColumnNullable.h>
2#include <Columns/ColumnString.h>
3#include <Columns/ColumnTuple.h>
4#include <Columns/ColumnsNumber.h>
5#include <Columns/IColumn.h>
6#include <DataTypes/DataTypeTuple.h>
7#include <DataTypes/DataTypesNumber.h>
8#include <Functions/FunctionFactory.h>
9#include <Functions/FunctionHelpers.h>
10#include <Functions/IFunction.h>
11#include <Functions/castTypeToEither.h>
12#include <Interpreters/castColumn.h>
13#include <boost/math/distributions/normal.hpp>
14#include <Common/typeid_cast.h>
15
16
17namespace DB
18{
19
namespace ErrorCodes
{
    /// Declarations only (extern): error codes raised by the function implemented below.
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
25
26
27class FunctionTwoSampleProportionsZTest : public IFunction
28{
29public:
30static constexpr auto POOLED = "pooled";
31static constexpr auto UNPOOLED = "unpooled";
32
33static constexpr auto name = "proportionsZTest";
34
35static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionTwoSampleProportionsZTest>(); }
36
37String getName() const override { return name; }
38
39size_t getNumberOfArguments() const override { return 6; }
40ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {5}; }
41
42bool useDefaultImplementationForNulls() const override { return false; }
43bool useDefaultImplementationForConstants() const override { return true; }
44bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
45
46static DataTypePtr getReturnType()
47{
48auto float_data_type = std::make_shared<DataTypeNumber<Float64>>();
49DataTypes types(4, float_data_type);
50
51Strings names{"z_statistic", "p_value", "confidence_interval_low", "confidence_interval_high"};
52
53return std::make_shared<DataTypeTuple>(std::move(types), std::move(names));
54}
55
56DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
57{
58for (size_t i = 0; i < 4; ++i)
59{
60if (!isUInt(arguments[i].type))
61{
62throw Exception(
63ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
64"The {}th Argument of function {} must be an unsigned integer.",
65i + 1,
66getName());
67}
68}
69
70if (!isFloat(arguments[4].type))
71{
72throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
73"The fifth argument {} of function {} should be a float,",
74arguments[4].type->getName(),
75getName()};
76}
77
78/// There is an additional check for constancy in ExecuteImpl
79if (!isString(arguments[5].type) || !arguments[5].column)
80{
81throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
82"The sixth argument {} of function {} should be a constant string",
83arguments[5].type->getName(),
84getName()};
85}
86
87return getReturnType();
88}
89
90
91ColumnPtr executeImpl(const ColumnsWithTypeAndName & const_arguments, const DataTypePtr &, size_t input_rows_count) const override
92{
93auto arguments = const_arguments;
94/// Only last argument have to be constant
95for (size_t i = 0; i < 5; ++i)
96arguments[i].column = arguments[i].column->convertToFullColumnIfConst();
97
98static const auto uint64_data_type = std::make_shared<DataTypeNumber<UInt64>>();
99
100auto column_successes_x = castColumnAccurate(arguments[0], uint64_data_type);
101const auto & data_successes_x = checkAndGetColumn<ColumnVector<UInt64>>(column_successes_x.get())->getData();
102
103auto column_successes_y = castColumnAccurate(arguments[1], uint64_data_type);
104const auto & data_successes_y = checkAndGetColumn<ColumnVector<UInt64>>(column_successes_y.get())->getData();
105
106auto column_trials_x = castColumnAccurate(arguments[2], uint64_data_type);
107const auto & data_trials_x = checkAndGetColumn<ColumnVector<UInt64>>(column_trials_x.get())->getData();
108
109auto column_trials_y = castColumnAccurate(arguments[3], uint64_data_type);
110const auto & data_trials_y = checkAndGetColumn<ColumnVector<UInt64>>(column_trials_y.get())->getData();
111
112static const auto float64_data_type = std::make_shared<DataTypeNumber<Float64>>();
113
114auto column_confidence_level = castColumnAccurate(arguments[4], float64_data_type);
115const auto & data_confidence_level = checkAndGetColumn<ColumnVector<Float64>>(column_confidence_level.get())->getData();
116
117String usevar = checkAndGetColumnConst<ColumnString>(arguments[5].column.get())->getValue<String>();
118
119if (usevar != UNPOOLED && usevar != POOLED)
120throw Exception{ErrorCodes::BAD_ARGUMENTS,
121"The sixth argument {} of function {} must be equal to `pooled` or `unpooled`",
122arguments[5].type->getName(),
123getName()};
124
125const bool is_unpooled = (usevar == UNPOOLED);
126
127auto res_z_statistic = ColumnFloat64::create();
128auto & data_z_statistic = res_z_statistic->getData();
129data_z_statistic.reserve(input_rows_count);
130
131auto res_p_value = ColumnFloat64::create();
132auto & data_p_value = res_p_value->getData();
133data_p_value.reserve(input_rows_count);
134
135auto res_ci_lower = ColumnFloat64::create();
136auto & data_ci_lower = res_ci_lower->getData();
137data_ci_lower.reserve(input_rows_count);
138
139auto res_ci_upper = ColumnFloat64::create();
140auto & data_ci_upper = res_ci_upper->getData();
141data_ci_upper.reserve(input_rows_count);
142
143auto insert_values_into_result = [&data_z_statistic, &data_p_value, &data_ci_lower, &data_ci_upper](
144Float64 z_stat, Float64 p_value, Float64 lower, Float64 upper)
145{
146data_z_statistic.emplace_back(z_stat);
147data_p_value.emplace_back(p_value);
148data_ci_lower.emplace_back(lower);
149data_ci_upper.emplace_back(upper);
150};
151
152static constexpr Float64 nan = std::numeric_limits<Float64>::quiet_NaN();
153
154boost::math::normal_distribution<> nd(0.0, 1.0);
155
156for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
157{
158const UInt64 successes_x = data_successes_x[row_num];
159const UInt64 successes_y = data_successes_y[row_num];
160const UInt64 trials_x = data_trials_x[row_num];
161const UInt64 trials_y = data_trials_y[row_num];
162const Float64 confidence_level = data_confidence_level[row_num];
163
164const Float64 props_x = static_cast<Float64>(successes_x) / trials_x;
165const Float64 props_y = static_cast<Float64>(successes_y) / trials_y;
166const Float64 diff = props_x - props_y;
167const UInt64 trials_total = trials_x + trials_y;
168
169if (successes_x == 0 || successes_y == 0 || successes_x > trials_x || successes_y > trials_y || trials_total == 0
170|| !std::isfinite(confidence_level) || confidence_level < 0.0 || confidence_level > 1.0)
171{
172insert_values_into_result(nan, nan, nan, nan);
173continue;
174}
175
176Float64 se = std::sqrt(props_x * (1.0 - props_x) / trials_x + props_y * (1.0 - props_y) / trials_y);
177
178/// z-statistics
179/// z = \frac{ \bar{p_{1}} - \bar{p_{2}} }{ \sqrt{ \frac{ \bar{p_{1}} \left ( 1 - \bar{p_{1}} \right ) }{ n_{1} } \frac{ \bar{p_{2}} \left ( 1 - \bar{p_{2}} \right ) }{ n_{2} } } }
180Float64 zstat;
181if (is_unpooled)
182{
183zstat = (props_x - props_y) / se;
184}
185else
186{
187UInt64 successes_total = successes_x + successes_y;
188Float64 p_pooled = static_cast<Float64>(successes_total) / trials_total;
189Float64 trials_fact = 1.0 / trials_x + 1.0 / trials_y;
190zstat = diff / std::sqrt(p_pooled * (1.0 - p_pooled) * trials_fact);
191}
192
193if (unlikely(!std::isfinite(zstat)))
194{
195insert_values_into_result(nan, nan, nan, nan);
196continue;
197}
198
199// pvalue
200Float64 pvalue = 0;
201Float64 one_side = 1 - boost::math::cdf(nd, std::abs(zstat));
202pvalue = one_side * 2;
203
204// Confidence intervals
205Float64 d = props_x - props_y;
206Float64 z = -boost::math::quantile(nd, (1.0 - confidence_level) / 2.0);
207Float64 dist = z * se;
208Float64 ci_low = d - dist;
209Float64 ci_high = d + dist;
210
211insert_values_into_result(zstat, pvalue, ci_low, ci_high);
212}
213
214return ColumnTuple::create(
215Columns{std::move(res_z_statistic), std::move(res_p_value), std::move(res_ci_lower), std::move(res_ci_upper)});
216}
217};
218
219
220REGISTER_FUNCTION(ZTest)
221{
222factory.registerFunction<FunctionTwoSampleProportionsZTest>();
223}
224
225}
226