ClickHouse
186 строк · 6.2 Кб
1#include <AggregateFunctions/AggregateFunctionFactory.h>2#include <AggregateFunctions/FactoryHelpers.h>3#include <AggregateFunctions/Moments.h>4
5#include <AggregateFunctions/IAggregateFunction.h>6#include <AggregateFunctions/StatCommon.h>7#include <Columns/ColumnVector.h>8#include <Columns/ColumnTuple.h>9#include <Common/assert_cast.h>10#include <DataTypes/DataTypesNumber.h>11#include <DataTypes/DataTypeTuple.h>12#include <cmath>13
14
/// Error codes used by the exceptions thrown in this file. They are only declared
/// here (`extern`); their definitions are not part of this file.
/// NOTE(review): this namespace is opened at global scope, before `namespace DB`
/// begins - confirm that this placement is intended.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
20
21
22namespace DB23{
24struct Settings;25
26namespace
27{
28
29/// Returns tuple of (z-statistic, p-value, confidence-interval-low, confidence-interval-high)
30template <typename Data>31class AggregateFunctionMeanZTest :32public IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>33{
34private:35Float64 pop_var_x;36Float64 pop_var_y;37Float64 confidence_level;38
39public:40AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)41: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params, createResultType())42{43pop_var_x = params.at(0).safeGet<Float64>();44pop_var_y = params.at(1).safeGet<Float64>();45confidence_level = params.at(2).safeGet<Float64>();46
47if (!std::isfinite(pop_var_x) || !std::isfinite(pop_var_y) || !std::isfinite(confidence_level))48{49throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} requires finite parameter values.", Data::name);50}51
52if (pop_var_x < 0.0 || pop_var_y < 0.0)53{54throw Exception(ErrorCodes::BAD_ARGUMENTS,55"Population variance parameters must be larger than or equal to zero "56"in aggregate function {}.", Data::name);57}58
59if (confidence_level <= 0.0 || confidence_level >= 1.0)60{61throw Exception(ErrorCodes::BAD_ARGUMENTS, "Confidence level parameter must be between 0 and 1 in aggregate function {}.", Data::name);62}63}64
65String getName() const override66{67return Data::name;68}69
70static DataTypePtr createResultType()71{72DataTypes types
73{74std::make_shared<DataTypeNumber<Float64>>(),75std::make_shared<DataTypeNumber<Float64>>(),76std::make_shared<DataTypeNumber<Float64>>(),77std::make_shared<DataTypeNumber<Float64>>(),78};79
80Strings names
81{82"z_statistic",83"p_value",84"confidence_interval_low",85"confidence_interval_high"86};87
88return std::make_shared<DataTypeTuple>(89std::move(types),90std::move(names)91);92}93
94bool allocatesMemoryInArena() const override { return false; }95
96void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override97{98Float64 value = columns[0]->getFloat64(row_num);99UInt8 is_second = columns[1]->getUInt(row_num);100
101if (is_second)102this->data(place).addY(value);103else104this->data(place).addX(value);105}106
107void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override108{109this->data(place).merge(this->data(rhs));110}111
112void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override113{114this->data(place).write(buf);115}116
117void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override118{119this->data(place).read(buf);120}121
122void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override123{124auto [z_stat, p_value] = this->data(place).getResult(pop_var_x, pop_var_y);125auto [ci_low, ci_high] = this->data(place).getConfidenceIntervals(pop_var_x, pop_var_y, confidence_level);126
127/// Because p-value is a probability.128p_value = std::min(1.0, std::max(0.0, p_value));129
130auto & column_tuple = assert_cast<ColumnTuple &>(to);131auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));132auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));133auto & column_ci_low = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(2));134auto & column_ci_high = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(3));135
136column_stat.getData().push_back(z_stat);137column_value.getData().push_back(p_value);138column_ci_low.getData().push_back(ci_low);139column_ci_high.getData().push_back(ci_high);140}141};142
143
144struct MeanZTestData : public ZTestMoments<Float64>145{
146static constexpr auto name = "meanZTest";147
148std::pair<Float64, Float64> getResult(Float64 pop_var_x, Float64 pop_var_y) const149{150Float64 mean_x = getMeanX();151Float64 mean_y = getMeanY();152
153/// z = \frac{\bar{X_{1}} - \bar{X_{2}}}{\sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}}154Float64 zstat = (mean_x - mean_y) / getStandardError(pop_var_x, pop_var_y);155
156if (unlikely(!std::isfinite(zstat)))157return {std::numeric_limits<Float64>::quiet_NaN(), std::numeric_limits<Float64>::quiet_NaN()};158
159Float64 pvalue = 2.0 * boost::math::cdf(boost::math::normal(0.0, 1.0), -1.0 * std::abs(zstat));160
161return {zstat, pvalue};162}163};164
165AggregateFunctionPtr createAggregateFunctionMeanZTest(166const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)167{
168assertBinary(name, argument_types);169
170if (parameters.size() != 3)171throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires three parameter.", name);172
173if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))174throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} only supports numerical types", name);175
176return std::make_shared<AggregateFunctionMeanZTest<MeanZTestData>>(argument_types, parameters);177}
178
179}
180
181void registerAggregateFunctionMeanZTest(AggregateFunctionFactory & factory)182{
183factory.registerFunction("meanZTest", createAggregateFunctionMeanZTest);184}
185
186}
187