ClickHouse

Форк
0
/
AggregateFunctionMannWhitney.cpp 
271 строка · 9.0 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>

#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/StatCommon.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <Common/assert_cast.h>
#include <Common/PODArray.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <algorithm>
#include <cmath>
#include <limits>

#include <boost/math/distributions/normal.hpp>
18

19

20
/// Error codes thrown by this file.
/// Every `ErrorCodes::...` reference in this file is made from inside namespace DB, so the
/// declarations must live in DB::ErrorCodes (where, presumably, the project defines them).
/// Declaring them at global scope would make those references either fail to compile or
/// fail to link against the real definitions.
namespace DB
{
namespace ErrorCodes
{
    extern const int NOT_IMPLEMENTED;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int BAD_ARGUMENTS;
}
}
27

28
namespace DB
29
{
30

31
struct Settings;
32

33
namespace
34
{
35

36
struct MannWhitneyData : public StatisticalSample<Float64, Float64>
{
    /* The null hypothesis H0 is "for randomly selected values X and Y from two populations,
     * the probability of X being greater than Y is equal to the probability of Y being greater than X",
     * or equivalently "the distribution F of the first sample equals the distribution G of the second sample".
     * The alternative hypothesis H1 is then "two-sided" (F != G), "less" (F < G) or "greater" (F > G). */
    enum class Alternative
    {
        TwoSided,
        Less,
        Greater
    };

    /// Computes the Mann-Whitney U statistic and its p-value using the normal approximation.
    /// The behaviour equals to the similar function from scipy.
    /// https://github.com/scipy/scipy/blob/ab9e9f17e0b7b2d618c4d4d8402cd4c0c200d6c0/scipy/stats/stats.py#L6978
    ///
    /// @param alternative            the alternative hypothesis H1 to test against.
    /// @param continuity_correction  if true, the mean of the U distribution is shifted by 0.5.
    /// @return {U, p_value}, where U = R1 - n1 * (n1 + 1) / 2 is the U statistic of the first sample
    ///         (R1 is the rank sum of the first sample within both samples combined).
    ///         Both values are NaN when the z-score is not finite (e.g. the standard deviation is zero).
    std::pair<Float64, Float64> getResult(Alternative alternative, bool continuity_correction)
    {
        ConcatenatedSamples both(this->x, this->y);
        RanksArray ranks;
        Float64 tie_correction = 0;

        /// Compute ranks according to both samples; the first sample occupies indices [0, size_x).
        std::tie(ranks, tie_correction) = computeRanksAndTieCorrection(both);

        const Float64 n1 = this->size_x;
        const Float64 n2 = this->size_y;

        /// R1: rank sum of the first sample.
        Float64 r1 = 0;
        for (size_t i = 0; i < n1; ++i)
            r1 += ranks[i];

        /// u1 = n1 * n2 - (R1 - n1 * (n1 + 1) / 2) is the U statistic of the second sample,
        /// u2 = R1 - n1 * (n1 + 1) / 2 is the U statistic of the first sample; u1 + u2 == n1 * n2.
        const Float64 u1 = n1 * n2 + (n1 * (n1 + 1.)) / 2. - r1;
        const Float64 u2 = n1 * n2 - u1;

        /// The distribution of U-statistic under null hypothesis H0 is symmetric with respect to meanrank.
        const Float64 meanrank = n1 * n2 / 2. + 0.5 * continuity_correction;
        const Float64 sd = std::sqrt(tie_correction * n1 * n2 * (n1 + n2 + 1) / 12.0);

        Float64 u = 0;
        if (alternative == Alternative::TwoSided)
            /// There is no difference which u_i to take as u, because z will differ only in sign and we take std::abs() from it.
            u = std::max(u1, u2);
        else if (alternative == Alternative::Less)
            u = u1;
        else if (alternative == Alternative::Greater)
            u = u2;

        Float64 z = (u - meanrank) / sd;

        /// E.g. sd == 0 (all values tied): the test is undefined.
        if (unlikely(!std::isfinite(z)))
            return {std::numeric_limits<Float64>::quiet_NaN(), std::numeric_limits<Float64>::quiet_NaN()};

        if (alternative == Alternative::TwoSided)
            z = std::abs(z);

        auto standard_normal_distribution = boost::math::normal_distribution<Float64>();
        auto cdf = boost::math::cdf(standard_normal_distribution, z);

        Float64 p_value = 0;
        if (alternative == Alternative::TwoSided)
            p_value = 2 - 2 * cdf;
        else
            p_value = 1 - cdf;

        return {u2, p_value};
    }

private:
    using Sample = typename StatisticalSample<Float64, Float64>::SampleX;

    /// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation.
    /// Presents `first` followed by `second` as one read-only sequence of first.size() + second.size() values.
    class ConcatenatedSamples
    {
        public:
            ConcatenatedSamples(const Sample & first_, const Sample & second_)
                : first(first_), second(second_) {}

            /// @param ind  index in [0, size()); indices past first.size() address the second sample.
            const Float64 & operator[](size_t ind) const
            {
                if (ind < first.size())
                    return first[ind];
                /// Elements of the second sample follow the first one, so shift the index back by first.size().
                /// Taking the index modulo first.size() would be wrong whenever second.size() > first.size().
                return second[ind - first.size()];
            }

            size_t size() const
            {
                return first.size() + second.size();
            }

        private:
            const Sample & first;
            const Sample & second;
    };
};
130

131
class AggregateFunctionMannWhitney final:
132
    public IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney>
133
{
134
private:
135
    using Alternative = typename MannWhitneyData::Alternative;
136
    Alternative alternative;
137
    bool continuity_correction{true};
138

139
public:
140
    explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
141
        : IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {}, createResultType())
142
    {
143
        if (params.size() > 2)
144
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
145

146
        if (params.empty())
147
        {
148
            alternative = Alternative::TwoSided;
149
            return;
150
        }
151

152
        if (params[0].getType() != Field::Types::String)
153
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
154

155
        const auto & param = params[0].get<String>();
156
        if (param == "two-sided")
157
            alternative = Alternative::TwoSided;
158
        else if (param == "less")
159
            alternative = Alternative::Less;
160
        else if (param == "greater")
161
            alternative = Alternative::Greater;
162
        else
163
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
164
                    "It must be one of: 'two-sided', 'less', 'greater'", getName());
165

166
        if (params.size() != 2)
167
            return;
168

169
        if (params[1].getType() != Field::Types::UInt64)
170
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a UInt64", getName());
171

172
        continuity_correction = static_cast<bool>(params[1].get<UInt64>());
173
    }
174

175
    String getName() const override
176
    {
177
        return "mannWhitneyUTest";
178
    }
179

180
    bool allocatesMemoryInArena() const override { return true; }
181

182
    static DataTypePtr createResultType()
183
    {
184
        DataTypes types
185
        {
186
            std::make_shared<DataTypeNumber<Float64>>(),
187
            std::make_shared<DataTypeNumber<Float64>>(),
188
        };
189

190
        Strings names
191
        {
192
            "u_statistic",
193
            "p_value"
194
        };
195

196
        return std::make_shared<DataTypeTuple>(
197
            std::move(types),
198
            std::move(names)
199
        );
200
    }
201

202
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
203
    {
204
        Float64 value = columns[0]->getFloat64(row_num);
205
        UInt8 is_second = columns[1]->getUInt(row_num);
206

207
        if (is_second)
208
            data(place).addY(value, arena);
209
        else
210
            data(place).addX(value, arena);
211
    }
212

213
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
214
    {
215
        auto & a = data(place);
216
        const auto & b = data(rhs);
217

218
        a.merge(b, arena);
219
    }
220

221
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
222
    {
223
        data(place).write(buf);
224
    }
225

226
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
227
    {
228
        data(place).read(buf, arena);
229
    }
230

231
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
232
    {
233
        if (!data(place).size_x || !data(place).size_y)
234
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
235

236
        auto [u_statistic, p_value] = data(place).getResult(alternative, continuity_correction);
237

238
        /// Because p-value is a probability.
239
        p_value = std::min(1.0, std::max(0.0, p_value));
240

241
        auto & column_tuple = assert_cast<ColumnTuple &>(to);
242
        auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
243
        auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
244

245
        column_stat.getData().push_back(u_statistic);
246
        column_value.getData().push_back(p_value);
247
    }
248

249
};
250

251

252
/// Factory for mannWhitneyUTest.
/// Validates the argument count via assertBinary, requires both arguments to be numeric,
/// and forwards the optional parameters to the AggregateFunctionMannWhitney constructor.
/// @throws NOT_IMPLEMENTED if either argument type is not numeric.
AggregateFunctionPtr createAggregateFunctionMannWhitneyUTest(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertBinary(name, argument_types);

    const bool all_numeric = isNumber(argument_types[0]) && isNumber(argument_types[1]);
    if (!all_numeric)
        throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Aggregate function {} only supports numerical types", name);

    return std::make_shared<AggregateFunctionMannWhitney>(argument_types, parameters);
}
262

263
}
264

265

266
/// Makes mannWhitneyUTest available through the aggregate function factory,
/// backed by createAggregateFunctionMannWhitneyUTest.
void registerAggregateFunctionMannWhitney(AggregateFunctionFactory & factory)
{
    const String function_name = "mannWhitneyUTest";
    factory.registerFunction(function_name, createAggregateFunctionMannWhitneyUTest);
}
270

271
}
272

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.