ClickHouse
356 строк · 12.1 Кб
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/StatCommon.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <Common/Exception.h>
#include <Common/assert_cast.h>
#include <Common/PODArray_fwd.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
14
15
/// Every use of ErrorCodes:: in this file is inside namespace DB, where unqualified lookup
/// finds DB::ErrorCodes before any global ::ErrorCodes. The declarations therefore have to
/// live in DB::ErrorCodes; in a global ::ErrorCodes they declared unrelated entities.
namespace DB
{
namespace ErrorCodes
{
    extern const int NOT_IMPLEMENTED;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int BAD_ARGUMENTS;
}
}
23
24namespace DB
25{
26
27struct Settings;
28
29namespace
30{
31
32struct KolmogorovSmirnov : public StatisticalSample<Float64, Float64>
33{
34enum class Alternative
35{
36TwoSided,
37Less,
38Greater
39};
40
41std::pair<Float64, Float64> getResult(Alternative alternative, String method)
42{
43::sort(x.begin(), x.end());
44::sort(y.begin(), y.end());
45
46Float64 max_s = std::numeric_limits<Float64>::min();
47Float64 min_s = std::numeric_limits<Float64>::max();
48Float64 now_s = 0;
49UInt64 pos_x = 0;
50UInt64 pos_y = 0;
51UInt64 pos_tmp;
52UInt64 n1 = x.size();
53UInt64 n2 = y.size();
54
55const Float64 n1_d = 1. / n1;
56const Float64 n2_d = 1. / n2;
57const Float64 tol = 1e-7;
58
59// reference: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
60while (pos_x < x.size() && pos_y < y.size())
61{
62if (likely(fabs(x[pos_x] - y[pos_y]) >= tol))
63{
64if (x[pos_x] < y[pos_y])
65{
66now_s += n1_d;
67++pos_x;
68}
69else
70{
71now_s -= n2_d;
72++pos_y;
73}
74}
75else
76{
77pos_tmp = pos_x + 1;
78while (pos_tmp < x.size() && unlikely(fabs(x[pos_tmp] - x[pos_x]) <= tol))
79pos_tmp++;
80now_s += n1_d * (pos_tmp - pos_x);
81pos_x = pos_tmp;
82pos_tmp = pos_y + 1;
83while (pos_tmp < y.size() && unlikely(fabs(y[pos_tmp] - y[pos_y]) <= tol))
84pos_tmp++;
85now_s -= n2_d * (pos_tmp - pos_y);
86pos_y = pos_tmp;
87}
88max_s = std::max(max_s, now_s);
89min_s = std::min(min_s, now_s);
90}
91now_s += n1_d * (x.size() - pos_x) - n2_d * (y.size() - pos_y);
92min_s = std::min(min_s, now_s);
93max_s = std::max(max_s, now_s);
94
95Float64 d = 0;
96if (alternative == Alternative::TwoSided)
97d = std::max(std::abs(max_s), std::abs(min_s));
98else if (alternative == Alternative::Less)
99d = -min_s;
100else if (alternative == Alternative::Greater)
101d = max_s;
102
103UInt64 g = std::__gcd(n1, n2);
104UInt64 nx_g = n1 / g;
105UInt64 ny_g = n2 / g;
106
107if (method == "auto")
108method = std::max(n1, n2) <= 10000 ? "exact" : "asymptotic";
109else if (method == "exact" && nx_g >= std::numeric_limits<Int32>::max() / ny_g)
110method = "asymptotic";
111
112Float64 p_value = std::numeric_limits<Float64>::infinity();
113
114if (method == "exact")
115{
116/* reference:
117* Gunar Schröer and Dietrich Trenkler
118* Exact and Randomization Distributions of Kolmogorov-Smirnov, Tests for Two or Three Samples
119*
120* and
121*
122* Thomas Viehmann
123* Numerically more stable computation of the p-values for the two-sample Kolmogorov-Smirnov test
124*/
125if (n2 > n1)
126std::swap(n1, n2);
127
128const Float64 f_n1 = static_cast<Float64>(n1);
129const Float64 f_n2 = static_cast<Float64>(n2);
130const Float64 k_d = (0.5 + floor(d * f_n2 * f_n1 - tol)) / (f_n2 * f_n1);
131PaddedPODArray<Float64> c(n1 + 1);
132
133auto check = alternative == Alternative::TwoSided ?
134[](const Float64 & q, const Float64 & r, const Float64 & s) { return fabs(r - s) >= q; }
135: [](const Float64 & q, const Float64 & r, const Float64 & s) { return r - s >= q; };
136
137c[0] = 0;
138for (UInt64 j = 1; j <= n1; j++)
139if (check(k_d, 0., j / f_n1))
140c[j] = 1.;
141else
142c[j] = c[j - 1];
143
144for (UInt64 i = 1; i <= n2; i++)
145{
146if (check(k_d, i / f_n2, 0.))
147c[0] = 1.;
148for (UInt64 j = 1; j <= n1; j++)
149if (check(k_d, i / f_n2, j / f_n1))
150c[j] = 1.;
151else
152{
153Float64 v = i / static_cast<Float64>(i + j);
154Float64 w = j / static_cast<Float64>(i + j);
155c[j] = v * c[j] + w * c[j - 1];
156}
157}
158p_value = c[n1];
159}
160else if (method == "asymp" || method == "asymptotic")
161{
162Float64 n = std::min(n1, n2);
163Float64 m = std::max(n1, n2);
164Float64 p = sqrt((n * m) / (n + m)) * d;
165
166if (alternative == Alternative::TwoSided)
167{
168/* reference:
169* J.DURBIN
170* Distribution theory for tests based on the sample distribution function
171*/
172Float64 new_val, old_val, s, w, z;
173UInt64 k_max = static_cast<UInt64>(sqrt(2 - log(tol)));
174
175if (p < 1)
176{
177z = - (M_PI_2 * M_PI_4) / (p * p);
178w = log(p);
179s = 0;
180for (UInt64 k = 1; k < k_max; k += 2)
181s += exp(k * k * z - w);
182p = s / 0.398942280401432677939946059934;
183}
184else
185{
186z = -2 * p * p;
187s = -1;
188UInt64 k = 1;
189old_val = 0;
190new_val = 1;
191while (fabs(old_val - new_val) > tol)
192{
193old_val = new_val;
194new_val += 2 * s * exp(z * k * k);
195s *= -1;
196k++;
197}
198p = new_val;
199}
200p_value = 1 - p;
201}
202else
203{
204/* reference:
205* J. L. HODGES, Jr
206* The significance probability of the Smirnov two-sample test
207*/
208
209// Use Hodges' suggested approximation Eqn 5.3
210// Requires m to be the larger of (n1, n2)
211Float64 expt = -2 * p * p - 2 * p * (m + 2 * n) / sqrt(m * n * (m + n)) / 3.0;
212p_value = exp(expt);
213}
214}
215return {d, p_value};
216}
217
218};
219
220class AggregateFunctionKolmogorovSmirnov final:
221public IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov>
222{
223private:
224using Alternative = typename KolmogorovSmirnov::Alternative;
225Alternative alternative = Alternative::TwoSided;
226String method = "auto";
227
228public:
229explicit AggregateFunctionKolmogorovSmirnov(const DataTypes & arguments, const Array & params)
230: IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov> ({arguments}, {}, createResultType())
231{
232if (params.size() > 2)
233throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
234
235if (params.empty())
236return;
237
238if (params[0].getType() != Field::Types::String)
239throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
240
241const auto & param = params[0].get<String>();
242if (param == "two-sided")
243alternative = Alternative::TwoSided;
244else if (param == "less")
245alternative = Alternative::Less;
246else if (param == "greater")
247alternative = Alternative::Greater;
248else
249throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
250"It must be one of: 'two-sided', 'less', 'greater'", getName());
251
252if (params.size() != 2)
253return;
254
255if (params[1].getType() != Field::Types::String)
256throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a String", getName());
257
258method = params[1].get<String>();
259if (method != "auto" && method != "exact" && method != "asymp" && method != "asymptotic")
260throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown method in aggregate function {}. "
261"It must be one of: 'auto', 'exact', 'asymp' (or 'asymptotic')", getName());
262}
263
264String getName() const override
265{
266return "kolmogorovSmirnovTest";
267}
268
269bool allocatesMemoryInArena() const override { return true; }
270
271static DataTypePtr createResultType()
272{
273DataTypes types
274{
275std::make_shared<DataTypeNumber<Float64>>(),
276std::make_shared<DataTypeNumber<Float64>>(),
277};
278
279Strings names
280{
281"d_statistic",
282"p_value"
283};
284
285return std::make_shared<DataTypeTuple>(
286std::move(types),
287std::move(names)
288);
289}
290
291void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
292{
293Float64 value = columns[0]->getFloat64(row_num);
294UInt8 is_second = columns[1]->getUInt(row_num);
295if (is_second)
296data(place).addY(value, arena);
297else
298data(place).addX(value, arena);
299}
300
301void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
302{
303data(place).merge(data(rhs), arena);
304}
305
306void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
307{
308data(place).write(buf);
309}
310
311void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
312{
313data(place).read(buf, arena);
314}
315
316void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
317{
318if (!data(place).size_x || !data(place).size_y)
319throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
320
321auto [d_statistic, p_value] = data(place).getResult(alternative, method);
322
323/// Because p-value is a probability.
324p_value = std::min(1.0, std::max(0.0, p_value));
325
326auto & column_tuple = assert_cast<ColumnTuple &>(to);
327auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
328auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
329
330column_stat.getData().push_back(d_statistic);
331column_value.getData().push_back(p_value);
332}
333
334};
335
336
337AggregateFunctionPtr createAggregateFunctionKolmogorovSmirnovTest(
338const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
339{
340assertBinary(name, argument_types);
341
342if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
343throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Aggregate function {} only supports numerical types", name);
344
345return std::make_shared<AggregateFunctionKolmogorovSmirnov>(argument_types, parameters);
346}
347
348
349}
350
351void registerAggregateFunctionKolmogorovSmirnovTest(AggregateFunctionFactory & factory)
352{
353factory.registerFunction("kolmogorovSmirnovTest", createAggregateFunctionKolmogorovSmirnovTest, AggregateFunctionFactory::CaseInsensitive);
354}
355
356}
357