// ClickHouse
// 271 lines · 9.0 KB
1#include <AggregateFunctions/AggregateFunctionFactory.h>2#include <AggregateFunctions/FactoryHelpers.h>3
4#include <AggregateFunctions/IAggregateFunction.h>5#include <AggregateFunctions/StatCommon.h>6#include <Columns/ColumnVector.h>7#include <Columns/ColumnTuple.h>8#include <Common/assert_cast.h>9#include <Common/PODArray.h>10#include <DataTypes/DataTypesDecimal.h>11#include <DataTypes/DataTypeNullable.h>12#include <DataTypes/DataTypesNumber.h>13#include <DataTypes/DataTypeTuple.h>14#include <IO/ReadHelpers.h>15#include <limits>16
17#include <boost/math/distributions/normal.hpp>18
19
/// Error codes used by this translation unit. They are defined in ClickHouse's DB::ErrorCodes,
/// so they must be declared in that namespace (a global ::ErrorCodes would name different symbols).
namespace DB::ErrorCodes
{
extern const int NOT_IMPLEMENTED;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}
27
28namespace DB29{
30
31struct Settings;32
33namespace
34{
35
36struct MannWhitneyData : public StatisticalSample<Float64, Float64>37{
38/*Since null hypothesis is "for randomly selected values X and Y from two populations,39*the probability of X being greater than Y is equal to the probability of Y being greater than X".
40*Or "the distribution F of first sample equals to the distribution G of second sample".
41*Then alternative for this hypothesis (H1) is "two-sided"(F != G), "less"(F < G), "greater" (F > G). */
42enum class Alternative43{44TwoSided,45Less,46Greater
47};48
49/// The behaviour equals to the similar function from scipy.50/// https://github.com/scipy/scipy/blob/ab9e9f17e0b7b2d618c4d4d8402cd4c0c200d6c0/scipy/stats/stats.py#L697851std::pair<Float64, Float64> getResult(Alternative alternative, bool continuity_correction)52{53ConcatenatedSamples both(this->x, this->y);54RanksArray ranks;55Float64 tie_correction;56
57/// Compute ranks according to both samples.58std::tie(ranks, tie_correction) = computeRanksAndTieCorrection(both);59
60const Float64 n1 = this->size_x;61const Float64 n2 = this->size_y;62
63Float64 r1 = 0;64for (size_t i = 0; i < n1; ++i)65r1 += ranks[i];66
67const Float64 u1 = n1 * n2 + (n1 * (n1 + 1.)) / 2. - r1;68const Float64 u2 = n1 * n2 - u1;69
70/// The distribution of U-statistic under null hypothesis H0 is symmetric with respect to meanrank.71const Float64 meanrank = n1 * n2 /2. + 0.5 * continuity_correction;72const Float64 sd = std::sqrt(tie_correction * n1 * n2 * (n1 + n2 + 1) / 12.0);73
74Float64 u = 0;75if (alternative == Alternative::TwoSided)76/// There is no difference which u_i to take as u, because z will be differ only in sign and we take std::abs() from it.77u = std::max(u1, u2);78else if (alternative == Alternative::Less)79u = u1;80else if (alternative == Alternative::Greater)81u = u2;82
83Float64 z = (u - meanrank) / sd;84
85if (unlikely(!std::isfinite(z)))86return {std::numeric_limits<Float64>::quiet_NaN(), std::numeric_limits<Float64>::quiet_NaN()};87
88if (alternative == Alternative::TwoSided)89z = std::abs(z);90
91auto standard_normal_distribution = boost::math::normal_distribution<Float64>();92auto cdf = boost::math::cdf(standard_normal_distribution, z);93
94Float64 p_value = 0;95if (alternative == Alternative::TwoSided)96p_value = 2 - 2 * cdf;97else98p_value = 1 - cdf;99
100return {u2, p_value};101}102
103private:104using Sample = typename StatisticalSample<Float64, Float64>::SampleX;105
106/// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation.107class ConcatenatedSamples108{109public:110ConcatenatedSamples(const Sample & first_, const Sample & second_)111: first(first_), second(second_) {}112
113const Float64 & operator[](size_t ind) const114{115if (ind < first.size())116return first[ind];117return second[ind % first.size()];118}119
120size_t size() const121{122return first.size() + second.size();123}124
125private:126const Sample & first;127const Sample & second;128};129};130
131class AggregateFunctionMannWhitney final:132public IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney>133{
134private:135using Alternative = typename MannWhitneyData::Alternative;136Alternative alternative;137bool continuity_correction{true};138
139public:140explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)141: IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}, createResultType())142{143if (params.size() > 2)144throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());145
146if (params.empty())147{148alternative = Alternative::TwoSided;149return;150}151
152if (params[0].getType() != Field::Types::String)153throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());154
155const auto & param = params[0].get<String>();156if (param == "two-sided")157alternative = Alternative::TwoSided;158else if (param == "less")159alternative = Alternative::Less;160else if (param == "greater")161alternative = Alternative::Greater;162else163throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "164"It must be one of: 'two-sided', 'less', 'greater'", getName());165
166if (params.size() != 2)167return;168
169if (params[1].getType() != Field::Types::UInt64)170throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a UInt64", getName());171
172continuity_correction = static_cast<bool>(params[1].get<UInt64>());173}174
175String getName() const override176{177return "mannWhitneyUTest";178}179
180bool allocatesMemoryInArena() const override { return true; }181
182static DataTypePtr createResultType()183{184DataTypes types
185{186std::make_shared<DataTypeNumber<Float64>>(),187std::make_shared<DataTypeNumber<Float64>>(),188};189
190Strings names
191{192"u_statistic",193"p_value"194};195
196return std::make_shared<DataTypeTuple>(197std::move(types),198std::move(names)199);200}201
202void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override203{204Float64 value = columns[0]->getFloat64(row_num);205UInt8 is_second = columns[1]->getUInt(row_num);206
207if (is_second)208data(place).addY(value, arena);209else210data(place).addX(value, arena);211}212
213void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override214{215auto & a = data(place);216const auto & b = data(rhs);217
218a.merge(b, arena);219}220
221void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override222{223data(place).write(buf);224}225
226void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override227{228data(place).read(buf, arena);229}230
231void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override232{233if (!data(place).size_x || !data(place).size_y)234throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());235
236auto [u_statistic, p_value] = data(place).getResult(alternative, continuity_correction);237
238/// Because p-value is a probability.239p_value = std::min(1.0, std::max(0.0, p_value));240
241auto & column_tuple = assert_cast<ColumnTuple &>(to);242auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));243auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));244
245column_stat.getData().push_back(u_statistic);246column_value.getData().push_back(p_value);247}248
249};250
251
252AggregateFunctionPtr createAggregateFunctionMannWhitneyUTest(253const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)254{
255assertBinary(name, argument_types);256
257if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))258throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Aggregate function {} only supports numerical types", name);259
260return std::make_shared<AggregateFunctionMannWhitney>(argument_types, parameters);261}
262
263}
264
265
266void registerAggregateFunctionMannWhitney(AggregateFunctionFactory & factory)267{
268factory.registerFunction("mannWhitneyUTest", createAggregateFunctionMannWhitneyUTest);269}
270
271}
272