ClickHouse
293 строки · 11.2 Кб
1#include <cfloat>2#include <cmath>3
4#include <boost/math/distributions/normal.hpp>5
6#include <DataTypes/DataTypeTuple.h>7#include <DataTypes/DataTypesDecimal.h>8#include <DataTypes/DataTypesNumber.h>9#include <Columns/ColumnTuple.h>10#include <Columns/ColumnsNumber.h>11#include <Functions/FunctionFactory.h>12#include <Functions/FunctionHelpers.h>13#include <Functions/IFunction.h>14#include <Functions/castTypeToEither.h>15#include <Interpreters/castColumn.h>16
17
18namespace DB19{
20
/// Error codes used in this file. They are only declared here (`extern`);
/// no definition is visible in this source.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
25
26template <typename Impl>27class FunctionMinSampleSize : public IFunction28{
29public:30static constexpr auto name = Impl::name;31
32static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionMinSampleSize<Impl>>(); }33
34String getName() const override { return name; }35
36size_t getNumberOfArguments() const override { return Impl::num_args; }37ColumnNumbers getArgumentsThatAreAlwaysConstant() const override38{39return ColumnNumbers(std::begin(Impl::const_args), std::end(Impl::const_args));40}41
42bool useDefaultImplementationForNulls() const override { return false; }43bool useDefaultImplementationForConstants() const override { return true; }44bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }45
46static DataTypePtr getReturnType()47{48auto float_64_type = std::make_shared<DataTypeNumber<Float64>>();49
50DataTypes types{51float_64_type,52float_64_type,53float_64_type,54};55
56Strings names{57"minimum_sample_size",58"detect_range_lower",59"detect_range_upper",60};61
62return std::make_shared<DataTypeTuple>(std::move(types), std::move(names));63}64
65DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override66{67Impl::validateArguments(arguments);68return getReturnType();69}70
71ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override72{73return Impl::execute(arguments, input_rows_count);74}75};76
/// Returns true iff `v` lies in [0, 1] and is at least DBL_EPSILON away from both 0 and 1.
/// NaN fails the range comparison and therefore yields false.
static bool isBetweenZeroAndOne(Float64 v)
{
    const bool inside_closed_interval = (v >= 0.0) && (v <= 1.0);
    if (!inside_closed_interval)
        return false;

    const Float64 distance_to_zero = std::fabs(v - 0.0);
    const Float64 distance_to_one = std::fabs(v - 1.0);

    /// Reject values that are indistinguishable from either end point.
    return !(distance_to_zero < DBL_EPSILON || distance_to_one < DBL_EPSILON);
}
81
82struct ContinuousImpl83{
84static constexpr auto name = "minSampleSizeContinuous";85static constexpr size_t num_args = 5;86static constexpr size_t const_args[] = {2, 3, 4};87
88static void validateArguments(const DataTypes & arguments)89{90for (size_t i = 0; i < arguments.size(); ++i)91{92if (!isNativeNumber(arguments[i]))93{94throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The {}th Argument of function {} must be a number.", i + 1, name);95}96}97}98
99static ColumnPtr execute(const ColumnsWithTypeAndName & arguments, size_t input_rows_count)100{101auto float_64_type = std::make_shared<DataTypeFloat64>();102auto baseline_argument = arguments[0];103baseline_argument.column = baseline_argument.column->convertToFullColumnIfConst();104auto baseline_column_untyped = castColumnAccurate(baseline_argument, float_64_type);105const auto * baseline_column = checkAndGetColumn<ColumnVector<Float64>>(*baseline_column_untyped);106const auto & baseline_column_data = baseline_column->getData();107
108auto sigma_argument = arguments[1];109sigma_argument.column = sigma_argument.column->convertToFullColumnIfConst();110auto sigma_column_untyped = castColumnAccurate(sigma_argument, float_64_type);111const auto * sigma_column = checkAndGetColumn<ColumnVector<Float64>>(*sigma_column_untyped);112const auto & sigma_column_data = sigma_column->getData();113
114const IColumn & col_mde = *arguments[2].column;115const IColumn & col_power = *arguments[3].column;116const IColumn & col_alpha = *arguments[4].column;117
118auto res_min_sample_size = ColumnFloat64::create();119auto & data_min_sample_size = res_min_sample_size->getData();120data_min_sample_size.reserve(input_rows_count);121
122auto res_detect_lower = ColumnFloat64::create();123auto & data_detect_lower = res_detect_lower->getData();124data_detect_lower.reserve(input_rows_count);125
126auto res_detect_upper = ColumnFloat64::create();127auto & data_detect_upper = res_detect_upper->getData();128data_detect_upper.reserve(input_rows_count);129
130/// Minimal Detectable Effect131const Float64 mde = col_mde.getFloat64(0);132/// Sufficient statistical power to detect a treatment effect133const Float64 power = col_power.getFloat64(0);134/// Significance level135const Float64 alpha = col_alpha.getFloat64(0);136
137boost::math::normal_distribution<> nd(0.0, 1.0);138
139for (size_t row_num = 0; row_num < input_rows_count; ++row_num)140{141/// Mean of control-metric142Float64 baseline = baseline_column_data[row_num];143/// Standard deviation of conrol-metric144Float64 sigma = sigma_column_data[row_num];145
146if (!std::isfinite(baseline) || !std::isfinite(sigma) || !isBetweenZeroAndOne(mde) || !isBetweenZeroAndOne(power)147|| !isBetweenZeroAndOne(alpha))148{149data_min_sample_size.emplace_back(std::numeric_limits<Float64>::quiet_NaN());150data_detect_lower.emplace_back(std::numeric_limits<Float64>::quiet_NaN());151data_detect_upper.emplace_back(std::numeric_limits<Float64>::quiet_NaN());152continue;153}154
155Float64 delta = baseline * mde;156
157using namespace boost::math;158/// https://towardsdatascience.com/required-sample-size-for-a-b-testing-6f6608dd330a159/// \frac{2\sigma^{2} * (Z_{1 - alpha /2} + Z_{power})^{2}}{\Delta^{2}}160Float64 min_sample_size
161= 2 * std::pow(sigma, 2) * std::pow(quantile(nd, 1.0 - alpha / 2) + quantile(nd, power), 2) / std::pow(delta, 2);162
163data_min_sample_size.emplace_back(min_sample_size);164data_detect_lower.emplace_back(baseline - delta);165data_detect_upper.emplace_back(baseline + delta);166}167
168return ColumnTuple::create(Columns{std::move(res_min_sample_size), std::move(res_detect_lower), std::move(res_detect_upper)});169}170};171
172
173struct ConversionImpl174{
175static constexpr auto name = "minSampleSizeConversion";176static constexpr size_t num_args = 4;177static constexpr size_t const_args[] = {1, 2, 3};178
179static void validateArguments(const DataTypes & arguments)180{181size_t arguments_size = arguments.size();182for (size_t i = 0; i < arguments_size; ++i)183{184if (!isFloat(arguments[i]))185{186throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The {}th argument of function {} must be a float.", i + 1, name);187}188}189}190
191static ColumnPtr execute(const ColumnsWithTypeAndName & arguments, size_t input_rows_count)192{193auto first_argument_column = castColumnAccurate(arguments[0], std::make_shared<DataTypeFloat64>());194
195if (const ColumnConst * const col_p1_const = checkAndGetColumnConst<ColumnVector<Float64>>(first_argument_column.get()))196{197const Float64 left_value = col_p1_const->template getValue<Float64>();198return process<true>(arguments, &left_value, input_rows_count);199}200else if (const ColumnVector<Float64> * const col_p1 = checkAndGetColumn<ColumnVector<Float64>>(first_argument_column.get()))201{202return process<false>(arguments, col_p1->getData().data(), input_rows_count);203}204else205{206throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The first argument of function {} must be a float.", name);207}208}209
210template <bool const_p1>211static ColumnPtr process(const ColumnsWithTypeAndName & arguments, const Float64 * col_p1, const size_t input_rows_count)212{213const IColumn & col_mde = *arguments[1].column;214const IColumn & col_power = *arguments[2].column;215const IColumn & col_alpha = *arguments[3].column;216
217auto res_min_sample_size = ColumnFloat64::create();218auto & data_min_sample_size = res_min_sample_size->getData();219data_min_sample_size.reserve(input_rows_count);220
221auto res_detect_lower = ColumnFloat64::create();222auto & data_detect_lower = res_detect_lower->getData();223data_detect_lower.reserve(input_rows_count);224
225auto res_detect_upper = ColumnFloat64::create();226auto & data_detect_upper = res_detect_upper->getData();227data_detect_upper.reserve(input_rows_count);228
229/// Minimal Detectable Effect230const Float64 mde = col_mde.getFloat64(0);231/// Sufficient statistical power to detect a treatment effect232const Float64 power = col_power.getFloat64(0);233/// Significance level234const Float64 alpha = col_alpha.getFloat64(0);235
236boost::math::normal_distribution<> nd(0.0, 1.0);237
238for (size_t row_num = 0; row_num < input_rows_count; ++row_num)239{240/// Proportion of control-metric241Float64 p1;242
243if constexpr (const_p1)244{245p1 = col_p1[0];246}247else if constexpr (!const_p1)248{249p1 = col_p1[row_num];250}251
252if (!std::isfinite(p1) || !isBetweenZeroAndOne(mde) || !isBetweenZeroAndOne(power) || !isBetweenZeroAndOne(alpha))253{254data_min_sample_size.emplace_back(std::numeric_limits<Float64>::quiet_NaN());255data_detect_lower.emplace_back(std::numeric_limits<Float64>::quiet_NaN());256data_detect_upper.emplace_back(std::numeric_limits<Float64>::quiet_NaN());257continue;258}259
260Float64 q1 = 1.0 - p1;261Float64 p2 = p1 + mde;262Float64 q2 = 1.0 - p2;263Float64 p_bar = (p1 + p2) / 2.0;264Float64 q_bar = 1.0 - p_bar;265
266using namespace boost::math;267/// https://towardsdatascience.com/required-sample-size-for-a-b-testing-6f6608dd330a268/// \frac{(Z_{1-alpha/2} * \sqrt{2*\bar{p}*\bar{q}} + Z_{power} * \sqrt{p1*q1+p2*q2})^{2}}{\Delta^{2}}269Float64 min_sample_size
270= std::pow(271quantile(nd, 1.0 - alpha / 2.0) * std::sqrt(2.0 * p_bar * q_bar) + quantile(nd, power) * std::sqrt(p1 * q1 + p2 * q2),2722)273/ std::pow(mde, 2);274
275data_min_sample_size.emplace_back(min_sample_size);276data_detect_lower.emplace_back(p1 - mde);277data_detect_upper.emplace_back(p1 + mde);278}279
280return ColumnTuple::create(Columns{std::move(res_min_sample_size), std::move(res_detect_lower), std::move(res_detect_upper)});281}282};283
284
285REGISTER_FUNCTION(MinSampleSize)286{
287factory.registerFunction<FunctionMinSampleSize<ContinuousImpl>>();288/// Needed for backward compatibility289factory.registerAlias("minSampleSizeContinous", FunctionMinSampleSize<ContinuousImpl>::name);290factory.registerFunction<FunctionMinSampleSize<ConversionImpl>>();291}
292
293}
294