ClickHouse

Форк
0
/
minSampleSize.cpp 
293 строки · 11.2 Кб
1
#include <cfloat>
2
#include <cmath>
3

4
#include <boost/math/distributions/normal.hpp>
5

6
#include <DataTypes/DataTypeTuple.h>
7
#include <DataTypes/DataTypesDecimal.h>
8
#include <DataTypes/DataTypesNumber.h>
9
#include <Columns/ColumnTuple.h>
10
#include <Columns/ColumnsNumber.h>
11
#include <Functions/FunctionFactory.h>
12
#include <Functions/FunctionHelpers.h>
13
#include <Functions/IFunction.h>
14
#include <Functions/castTypeToEither.h>
15
#include <Interpreters/castColumn.h>
16

17

18
namespace DB
19
{
20

21
/// Error codes thrown by this file; they are declared `extern` here and defined in another translation unit.
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
25

26
template <typename Impl>
27
class FunctionMinSampleSize : public IFunction
28
{
29
public:
30
    static constexpr auto name = Impl::name;
31

32
    static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionMinSampleSize<Impl>>(); }
33

34
    String getName() const override { return name; }
35

36
    size_t getNumberOfArguments() const override { return Impl::num_args; }
37
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
38
    {
39
        return ColumnNumbers(std::begin(Impl::const_args), std::end(Impl::const_args));
40
    }
41

42
    bool useDefaultImplementationForNulls() const override { return false; }
43
    bool useDefaultImplementationForConstants() const override { return true; }
44
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
45

46
    static DataTypePtr getReturnType()
47
    {
48
        auto float_64_type = std::make_shared<DataTypeNumber<Float64>>();
49

50
        DataTypes types{
51
            float_64_type,
52
            float_64_type,
53
            float_64_type,
54
        };
55

56
        Strings names{
57
            "minimum_sample_size",
58
            "detect_range_lower",
59
            "detect_range_upper",
60
        };
61

62
        return std::make_shared<DataTypeTuple>(std::move(types), std::move(names));
63
    }
64

65
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
66
    {
67
        Impl::validateArguments(arguments);
68
        return getReturnType();
69
    }
70

71
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
72
    {
73
        return Impl::execute(arguments, input_rows_count);
74
    }
75
};
76

77
/// Returns true when `v` lies inside (0, 1) and is at least DBL_EPSILON away from both endpoints.
/// NaN yields false because every comparison involving NaN is false.
static bool isBetweenZeroAndOne(Float64 v)
{
    /// On [0, 1] the distances to the endpoints are simply `v` and `1 - v`
    /// (the subtraction is exact for v >= 0.5, and for v < 0.5 it is far above DBL_EPSILON anyway).
    const bool far_from_zero = v >= DBL_EPSILON;
    const bool far_from_one = 1.0 - v >= DBL_EPSILON;
    return far_from_zero && far_from_one;
}
81

82
struct ContinuousImpl
83
{
84
    static constexpr auto name = "minSampleSizeContinuous";
85
    static constexpr size_t num_args = 5;
86
    static constexpr size_t const_args[] = {2, 3, 4};
87

88
    static void validateArguments(const DataTypes & arguments)
89
    {
90
        for (size_t i = 0; i < arguments.size(); ++i)
91
        {
92
            if (!isNativeNumber(arguments[i]))
93
            {
94
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The {}th Argument of function {} must be a number.", i + 1, name);
95
            }
96
        }
97
    }
98

99
    static ColumnPtr execute(const ColumnsWithTypeAndName & arguments, size_t input_rows_count)
100
    {
101
        auto float_64_type = std::make_shared<DataTypeFloat64>();
102
        auto baseline_argument = arguments[0];
103
        baseline_argument.column = baseline_argument.column->convertToFullColumnIfConst();
104
        auto baseline_column_untyped = castColumnAccurate(baseline_argument, float_64_type);
105
        const auto * baseline_column = checkAndGetColumn<ColumnVector<Float64>>(*baseline_column_untyped);
106
        const auto & baseline_column_data = baseline_column->getData();
107

108
        auto sigma_argument = arguments[1];
109
        sigma_argument.column = sigma_argument.column->convertToFullColumnIfConst();
110
        auto sigma_column_untyped = castColumnAccurate(sigma_argument, float_64_type);
111
        const auto * sigma_column = checkAndGetColumn<ColumnVector<Float64>>(*sigma_column_untyped);
112
        const auto & sigma_column_data = sigma_column->getData();
113

114
        const IColumn & col_mde = *arguments[2].column;
115
        const IColumn & col_power = *arguments[3].column;
116
        const IColumn & col_alpha = *arguments[4].column;
117

118
        auto res_min_sample_size = ColumnFloat64::create();
119
        auto & data_min_sample_size = res_min_sample_size->getData();
120
        data_min_sample_size.reserve(input_rows_count);
121

122
        auto res_detect_lower = ColumnFloat64::create();
123
        auto & data_detect_lower = res_detect_lower->getData();
124
        data_detect_lower.reserve(input_rows_count);
125

126
        auto res_detect_upper = ColumnFloat64::create();
127
        auto & data_detect_upper = res_detect_upper->getData();
128
        data_detect_upper.reserve(input_rows_count);
129

130
        /// Minimal Detectable Effect
131
        const Float64 mde = col_mde.getFloat64(0);
132
        /// Sufficient statistical power to detect a treatment effect
133
        const Float64 power = col_power.getFloat64(0);
134
        /// Significance level
135
        const Float64 alpha = col_alpha.getFloat64(0);
136

137
        boost::math::normal_distribution<> nd(0.0, 1.0);
138

139
        for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
140
        {
141
            /// Mean of control-metric
142
            Float64 baseline = baseline_column_data[row_num];
143
            /// Standard deviation of conrol-metric
144
            Float64 sigma = sigma_column_data[row_num];
145

146
            if (!std::isfinite(baseline) || !std::isfinite(sigma) || !isBetweenZeroAndOne(mde) || !isBetweenZeroAndOne(power)
147
                || !isBetweenZeroAndOne(alpha))
148
            {
149
                data_min_sample_size.emplace_back(std::numeric_limits<Float64>::quiet_NaN());
150
                data_detect_lower.emplace_back(std::numeric_limits<Float64>::quiet_NaN());
151
                data_detect_upper.emplace_back(std::numeric_limits<Float64>::quiet_NaN());
152
                continue;
153
            }
154

155
            Float64 delta = baseline * mde;
156

157
            using namespace boost::math;
158
            /// https://towardsdatascience.com/required-sample-size-for-a-b-testing-6f6608dd330a
159
            /// \frac{2\sigma^{2} * (Z_{1 - alpha /2} + Z_{power})^{2}}{\Delta^{2}}
160
            Float64 min_sample_size
161
                = 2 * std::pow(sigma, 2) * std::pow(quantile(nd, 1.0 - alpha / 2) + quantile(nd, power), 2) / std::pow(delta, 2);
162

163
            data_min_sample_size.emplace_back(min_sample_size);
164
            data_detect_lower.emplace_back(baseline - delta);
165
            data_detect_upper.emplace_back(baseline + delta);
166
        }
167

168
        return ColumnTuple::create(Columns{std::move(res_min_sample_size), std::move(res_detect_lower), std::move(res_detect_upper)});
169
    }
170
};
171

172

173
struct ConversionImpl
174
{
175
    static constexpr auto name = "minSampleSizeConversion";
176
    static constexpr size_t num_args = 4;
177
    static constexpr size_t const_args[] = {1, 2, 3};
178

179
    static void validateArguments(const DataTypes & arguments)
180
    {
181
        size_t arguments_size = arguments.size();
182
        for (size_t i = 0; i < arguments_size; ++i)
183
        {
184
            if (!isFloat(arguments[i]))
185
            {
186
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The {}th argument of function {} must be a float.", i + 1, name);
187
            }
188
        }
189
    }
190

191
    static ColumnPtr execute(const ColumnsWithTypeAndName & arguments, size_t input_rows_count)
192
    {
193
        auto first_argument_column = castColumnAccurate(arguments[0], std::make_shared<DataTypeFloat64>());
194

195
        if (const ColumnConst * const col_p1_const = checkAndGetColumnConst<ColumnVector<Float64>>(first_argument_column.get()))
196
        {
197
            const Float64 left_value = col_p1_const->template getValue<Float64>();
198
            return process<true>(arguments, &left_value, input_rows_count);
199
        }
200
        else if (const ColumnVector<Float64> * const col_p1 = checkAndGetColumn<ColumnVector<Float64>>(first_argument_column.get()))
201
        {
202
            return process<false>(arguments, col_p1->getData().data(), input_rows_count);
203
        }
204
        else
205
        {
206
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The first argument of function {} must be a float.", name);
207
        }
208
    }
209

210
    template <bool const_p1>
211
    static ColumnPtr process(const ColumnsWithTypeAndName & arguments, const Float64 * col_p1, const size_t input_rows_count)
212
    {
213
        const IColumn & col_mde = *arguments[1].column;
214
        const IColumn & col_power = *arguments[2].column;
215
        const IColumn & col_alpha = *arguments[3].column;
216

217
        auto res_min_sample_size = ColumnFloat64::create();
218
        auto & data_min_sample_size = res_min_sample_size->getData();
219
        data_min_sample_size.reserve(input_rows_count);
220

221
        auto res_detect_lower = ColumnFloat64::create();
222
        auto & data_detect_lower = res_detect_lower->getData();
223
        data_detect_lower.reserve(input_rows_count);
224

225
        auto res_detect_upper = ColumnFloat64::create();
226
        auto & data_detect_upper = res_detect_upper->getData();
227
        data_detect_upper.reserve(input_rows_count);
228

229
        /// Minimal Detectable Effect
230
        const Float64 mde = col_mde.getFloat64(0);
231
        /// Sufficient statistical power to detect a treatment effect
232
        const Float64 power = col_power.getFloat64(0);
233
        /// Significance level
234
        const Float64 alpha = col_alpha.getFloat64(0);
235

236
        boost::math::normal_distribution<> nd(0.0, 1.0);
237

238
        for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
239
        {
240
            /// Proportion of control-metric
241
            Float64 p1;
242

243
            if constexpr (const_p1)
244
            {
245
                p1 = col_p1[0];
246
            }
247
            else if constexpr (!const_p1)
248
            {
249
                p1 = col_p1[row_num];
250
            }
251

252
            if (!std::isfinite(p1) || !isBetweenZeroAndOne(mde) || !isBetweenZeroAndOne(power) || !isBetweenZeroAndOne(alpha))
253
            {
254
                data_min_sample_size.emplace_back(std::numeric_limits<Float64>::quiet_NaN());
255
                data_detect_lower.emplace_back(std::numeric_limits<Float64>::quiet_NaN());
256
                data_detect_upper.emplace_back(std::numeric_limits<Float64>::quiet_NaN());
257
                continue;
258
            }
259

260
            Float64 q1 = 1.0 - p1;
261
            Float64 p2 = p1 + mde;
262
            Float64 q2 = 1.0 - p2;
263
            Float64 p_bar = (p1 + p2) / 2.0;
264
            Float64 q_bar = 1.0 - p_bar;
265

266
            using namespace boost::math;
267
            /// https://towardsdatascience.com/required-sample-size-for-a-b-testing-6f6608dd330a
268
            /// \frac{(Z_{1-alpha/2} * \sqrt{2*\bar{p}*\bar{q}} + Z_{power} * \sqrt{p1*q1+p2*q2})^{2}}{\Delta^{2}}
269
            Float64 min_sample_size
270
                = std::pow(
271
                      quantile(nd, 1.0 - alpha / 2.0) * std::sqrt(2.0 * p_bar * q_bar) + quantile(nd, power) * std::sqrt(p1 * q1 + p2 * q2),
272
                      2)
273
                / std::pow(mde, 2);
274

275
            data_min_sample_size.emplace_back(min_sample_size);
276
            data_detect_lower.emplace_back(p1 - mde);
277
            data_detect_upper.emplace_back(p1 + mde);
278
        }
279

280
        return ColumnTuple::create(Columns{std::move(res_min_sample_size), std::move(res_detect_lower), std::move(res_detect_upper)});
281
    }
282
};
283

284

285
/// Registers minSampleSizeContinuous and minSampleSizeConversion with the function factory.
REGISTER_FUNCTION(MinSampleSize)
{
    factory.registerFunction<FunctionMinSampleSize<ContinuousImpl>>();
    factory.registerFunction<FunctionMinSampleSize<ConversionImpl>>();

    /// Misspelled legacy name, kept for backward compatibility with existing queries.
    factory.registerAlias("minSampleSizeContinous", FunctionMinSampleSize<ContinuousImpl>::name);
}
292

293
}
294

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.