ClickHouse

Форк
0
/
AggregateFunctionKolmogorovSmirnovTest.cpp 
356 строк · 12.1 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/StatCommon.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>
#include <Common/PODArray_fwd.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>

#include <numeric>
14

15

16
namespace ErrorCodes
17
{
18
    extern const int NOT_IMPLEMENTED;
19
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
20
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
21
    extern const int BAD_ARGUMENTS;
22
}
23

24
namespace DB
25
{
26

27
struct Settings;
28

29
namespace
30
{
31

32
/** State and computation for the two-sample Kolmogorov-Smirnov test.
  *
  * The samples are read from `x` and `y`, which are not declared in this struct
  * (presumably they are members of the StatisticalSample base).
  */
struct KolmogorovSmirnov : public StatisticalSample<Float64, Float64>
{
    /// Which extreme of the difference between the two empirical CDFs is used as the statistic.
    enum class Alternative
    {
        TwoSided,
        Less,
        Greater
    };

    /** Computes the KS statistic `d` and its p-value.
      *
      * `alternative` selects the statistic from the running difference s = ECDF(x) - ECDF(y):
      *   TwoSided -> max(|max s|, |min s|), Less -> -min s, Greater -> max s.
      * `method` is one of "auto", "exact", "asymp" / "asymptotic":
      *   - "auto" becomes "exact" when max(n1, n2) <= 10000 and "asymptotic" otherwise;
      *   - an explicit "exact" is downgraded to "asymptotic" when (n1 / g) * (n2 / g), g = gcd(n1, n2),
      *     would not fit into Int32.
      * For any other `method` string the returned p-value stays at +infinity.
      *
      * Returns {d, p_value}. The p-value is not clamped here.
      *
      * Side effect: `x` and `y` are sorted in place.
      * NOTE(review): both samples must be non-empty, otherwise 1 / n below is a division by zero.
      */
    std::pair<Float64, Float64> getResult(Alternative alternative, String method)
    {
        ::sort(x.begin(), x.end());
        ::sort(y.begin(), y.end());

        /// Both ECDFs are 0 before the first observation, so the running difference starts at 0
        /// and 0 is the correct initial value for both extremes.
        /// (std::numeric_limits<Float64>::min() is the smallest *positive* normal value, not a
        /// "most negative" sentinel, so it must not be used to seed a running maximum.)
        Float64 max_s = 0;
        Float64 min_s = 0;
        Float64 now_s = 0;
        UInt64 pos_x = 0;
        UInt64 pos_y = 0;
        UInt64 n1 = x.size();
        UInt64 n2 = y.size();

        const Float64 n1_d = 1. / n1;
        const Float64 n2_d = 1. / n2;
        const Float64 tol = 1e-7;

        // reference: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
        /// Merge-walk over both sorted samples: a step over x raises the difference by 1 / n1,
        /// a step over y lowers it by 1 / n2.
        while (pos_x < x.size() && pos_y < y.size())
        {
            if (likely(fabs(x[pos_x] - y[pos_y]) >= tol))
            {
                if (x[pos_x] < y[pos_y])
                {
                    now_s += n1_d;
                    ++pos_x;
                }
                else
                {
                    now_s -= n2_d;
                    ++pos_y;
                }
            }
            else
            {
                /// The heads of both samples are equal within `tol`: consume the whole run of
                /// (approximately) equal values from each side before recording the difference.
                UInt64 run_end = pos_x + 1;
                while (run_end < x.size() && unlikely(fabs(x[run_end] - x[pos_x]) <= tol))
                    ++run_end;
                now_s += n1_d * (run_end - pos_x);
                pos_x = run_end;

                run_end = pos_y + 1;
                while (run_end < y.size() && unlikely(fabs(y[run_end] - y[pos_y]) <= tol))
                    ++run_end;
                now_s -= n2_d * (run_end - pos_y);
                pos_y = run_end;
            }
            max_s = std::max(max_s, now_s);
            min_s = std::min(min_s, now_s);
        }

        /// Account for whatever is left in the sample that was not exhausted.
        now_s += n1_d * (x.size() - pos_x) - n2_d * (y.size() - pos_y);
        min_s = std::min(min_s, now_s);
        max_s = std::max(max_s, now_s);

        Float64 d = 0;
        if (alternative == Alternative::TwoSided)
            d = std::max(std::abs(max_s), std::abs(min_s));
        else if (alternative == Alternative::Less)
            d = -min_s;
        else if (alternative == Alternative::Greater)
            d = max_s;

        /// std::gcd is the standard C++17 facility; std::__gcd is a library-internal name.
        const UInt64 g = std::gcd(n1, n2);
        const UInt64 nx_g = n1 / g;
        const UInt64 ny_g = n2 / g;

        if (method == "auto")
            method = std::max(n1, n2) <= 10000 ? "exact" : "asymptotic";
        else if (method == "exact" && nx_g >= std::numeric_limits<Int32>::max() / ny_g)
            method = "asymptotic";

        Float64 p_value = std::numeric_limits<Float64>::infinity();

        if (method == "exact")
        {
            /* reference:
             * Gunar Schröer and Dietrich Trenkler
             * Exact and Randomization Distributions of Kolmogorov-Smirnov, Tests for Two or Three Samples
             *
             * and
             *
             * Thomas Viehmann
             * Numerically more stable computation of the p-values for the two-sample Kolmogorov-Smirnov test
             */

            /// From here on n1 is the larger of the two sizes.
            if (n2 > n1)
                std::swap(n1, n2);

            const Float64 f_n1 = static_cast<Float64>(n1);
            const Float64 f_n2 = static_cast<Float64>(n2);
            const Float64 k_d = (0.5 + floor(d * f_n2 * f_n1 - tol)) / (f_n2 * f_n1);
            PaddedPODArray<Float64> c(n1 + 1);

            /// Predicate deciding whether grid point (r, s) counts as "at least as extreme as k_d".
            auto check = alternative == Alternative::TwoSided ?
                         [](const Float64 & q, const Float64 & r, const Float64 & s) { return fabs(r - s) >= q; }
                       : [](const Float64 & q, const Float64 & r, const Float64 & s) { return r - s >= q; };

            /// Row i = 0 of the recurrence.
            c[0] = 0;
            for (UInt64 j = 1; j <= n1; j++)
                if (check(k_d, 0., j / f_n1))
                    c[j] = 1.;
                else
                    c[j] = c[j - 1];

            /// Rows i = 1..n2, updated in place: before the assignment c[j] holds the value from
            /// row i - 1 and c[j - 1] the already updated value from row i.
            for (UInt64 i = 1; i <= n2; i++)
            {
                if (check(k_d, i / f_n2, 0.))
                    c[0] = 1.;
                for (UInt64 j = 1; j <= n1; j++)
                    if (check(k_d, i / f_n2, j / f_n1))
                        c[j] = 1.;
                    else
                    {
                        Float64 v = i / static_cast<Float64>(i + j);
                        Float64 w = j / static_cast<Float64>(i + j);
                        c[j] = v * c[j] + w * c[j - 1];
                    }
            }
            p_value = c[n1];
        }
        else if (method == "asymp" || method == "asymptotic")
        {
            const Float64 n = static_cast<Float64>(std::min(n1, n2));
            const Float64 m = static_cast<Float64>(std::max(n1, n2));
            Float64 p = sqrt((n * m) / (n + m)) * d;

            if (alternative == Alternative::TwoSided)
            {
                /* reference:
                 * J.DURBIN
                 * Distribution theory for tests based on the sample distribution function
                 */
                const UInt64 k_max = static_cast<UInt64>(sqrt(2 - log(tol)));

                if (p <= 0)
                {
                    /// d == 0: the series below would evaluate exp(-inf + inf) = NaN.
                    /// The limiting CDF is 0 at this point, which yields a p-value of 1.
                    p = 0;
                }
                else if (p < 1)
                {
                    const Float64 z = - (M_PI_2 * M_PI_4) / (p * p);
                    const Float64 w = log(p);
                    Float64 s = 0;
                    for (UInt64 k = 1; k < k_max; k += 2)
                        s += exp(k * k * z - w);
                    /// 0.3989... is 1 / sqrt(2 * pi).
                    p = s / 0.398942280401432677939946059934;
                }
                else
                {
                    /// Alternating series 1 + 2 * sum_{k >= 1} (-1)^k * exp(-2 * k^2 * p^2),
                    /// summed until two consecutive partial sums differ by at most `tol`.
                    const Float64 z = -2 * p * p;
                    Float64 s = -1;
                    UInt64 k = 1;
                    Float64 old_val = 0;
                    Float64 new_val = 1;
                    while (fabs(old_val - new_val) > tol)
                    {
                        old_val = new_val;
                        new_val += 2 * s * exp(z * k * k);
                        s *= -1;
                        k++;
                    }
                    p = new_val;
                }
                p_value = 1 - p;
            }
            else
            {
                /* reference:
                 * J. L. HODGES, Jr
                 * The significance probability of the Smirnov two-sample test
                 */

                // Use Hodges' suggested approximation Eqn 5.3
                // Requires m to be the larger of (n1, n2)
                Float64 expt = -2 * p * p - 2 * p * (m + 2 * n) / sqrt(m * n * (m + n)) / 3.0;
                p_value = exp(expt);
            }
        }
        return {d, p_value};
    }

};
219

220
class AggregateFunctionKolmogorovSmirnov final:
221
    public IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov>
222
{
223
private:
224
    using Alternative = typename KolmogorovSmirnov::Alternative;
225
    Alternative alternative = Alternative::TwoSided;
226
    String method = "auto";
227

228
public:
229
    explicit AggregateFunctionKolmogorovSmirnov(const DataTypes & arguments, const Array & params)
230
        : IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov> ({arguments}, {}, createResultType())
231
    {
232
        if (params.size() > 2)
233
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
234

235
        if (params.empty())
236
            return;
237

238
        if (params[0].getType() != Field::Types::String)
239
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
240

241
        const auto & param = params[0].get<String>();
242
        if (param == "two-sided")
243
            alternative = Alternative::TwoSided;
244
        else if (param == "less")
245
            alternative = Alternative::Less;
246
        else if (param == "greater")
247
            alternative = Alternative::Greater;
248
        else
249
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
250
                    "It must be one of: 'two-sided', 'less', 'greater'", getName());
251

252
        if (params.size() != 2)
253
            return;
254

255
        if (params[1].getType() != Field::Types::String)
256
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a String", getName());
257

258
        method = params[1].get<String>();
259
        if (method != "auto" && method != "exact" && method != "asymp" && method != "asymptotic")
260
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown method in aggregate function {}. "
261
                    "It must be one of: 'auto', 'exact', 'asymp' (or 'asymptotic')", getName());
262
    }
263

264
    String getName() const override
265
    {
266
        return "kolmogorovSmirnovTest";
267
    }
268

269
    bool allocatesMemoryInArena() const override { return true; }
270

271
    static DataTypePtr createResultType()
272
    {
273
        DataTypes types
274
        {
275
            std::make_shared<DataTypeNumber<Float64>>(),
276
            std::make_shared<DataTypeNumber<Float64>>(),
277
        };
278

279
        Strings names
280
        {
281
            "d_statistic",
282
            "p_value"
283
        };
284

285
        return std::make_shared<DataTypeTuple>(
286
            std::move(types),
287
            std::move(names)
288
        );
289
    }
290

291
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
292
    {
293
        Float64 value = columns[0]->getFloat64(row_num);
294
        UInt8 is_second = columns[1]->getUInt(row_num);
295
        if (is_second)
296
            data(place).addY(value, arena);
297
        else
298
            data(place).addX(value, arena);
299
    }
300

301
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
302
    {
303
        data(place).merge(data(rhs), arena);
304
    }
305

306
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
307
    {
308
        data(place).write(buf);
309
    }
310

311
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
312
    {
313
        data(place).read(buf, arena);
314
    }
315

316
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
317
    {
318
        if (!data(place).size_x || !data(place).size_y)
319
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
320

321
        auto [d_statistic, p_value] = data(place).getResult(alternative, method);
322

323
        /// Because p-value is a probability.
324
        p_value = std::min(1.0, std::max(0.0, p_value));
325

326
        auto & column_tuple = assert_cast<ColumnTuple &>(to);
327
        auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
328
        auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
329

330
        column_stat.getData().push_back(d_statistic);
331
        column_value.getData().push_back(p_value);
332
    }
333

334
};
335

336

337
/// Factory callback for "kolmogorovSmirnovTest".
/// Runs assertBinary on the argument list, then requires both argument types to satisfy
/// isNumber (NOT_IMPLEMENTED is thrown otherwise) and constructs the aggregate function.
AggregateFunctionPtr createAggregateFunctionKolmogorovSmirnovTest(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertBinary(name, argument_types);

    const bool both_numeric = isNumber(argument_types[0]) && isNumber(argument_types[1]);
    if (both_numeric)
        return std::make_shared<AggregateFunctionKolmogorovSmirnov>(argument_types, parameters);

    throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Aggregate function {} only supports numerical types", name);
}
347

348

349
}
350

351
/// Registers the creator above in `factory` under the name "kolmogorovSmirnovTest",
/// passing the CaseInsensitive flag.
void registerAggregateFunctionKolmogorovSmirnovTest(AggregateFunctionFactory & factory)
{
    factory.registerFunction(
        "kolmogorovSmirnovTest",
        createAggregateFunctionKolmogorovSmirnovTest,
        AggregateFunctionFactory::CaseInsensitive);
}
355

356
}
357

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.