ClickHouse
183 строки · 5.2 Кб
1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/FactoryHelpers.h>
3#include <AggregateFunctions/Helpers.h>
4
5#include <Common/HashTable/HashMap.h>
6#include <Common/NaNUtils.h>
7
8#include <AggregateFunctions/IAggregateFunction.h>
9#include <AggregateFunctions/UniqVariadicHash.h>
10#include <DataTypes/DataTypesNumber.h>
11#include <Columns/ColumnVector.h>
12#include <Common/assert_cast.h>
13
14#include <cmath>
15
16
17namespace DB
18{
19struct Settings;
20
21namespace ErrorCodes
22{
23extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
24}
25
26namespace
27{
28
29/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function.
30* Entropy is measured in bits (base-2 logarithm is used).
31*/
32template <typename Value>
33struct EntropyData
34{
35using Weight = UInt64;
36
37using HashingMap = HashMapWithStackMemory<Value, Weight, HashCRC32<Value>, 4>;
38
39/// For the case of pre-hashed values.
40using TrivialMap = HashMapWithStackMemory<Value, Weight, UInt128TrivialHash, 4>;
41
42using Map = std::conditional_t<std::is_same_v<UInt128, Value>, TrivialMap, HashingMap>;
43
44Map map;
45
46void add(const Value & x)
47{
48if (!isNaN(x))
49++map[x];
50}
51
52void add(const Value & x, const Weight & weight)
53{
54if (!isNaN(x))
55map[x] += weight;
56}
57
58void merge(const EntropyData & rhs)
59{
60for (const auto & pair : rhs.map)
61map[pair.getKey()] += pair.getMapped();
62}
63
64void serialize(WriteBuffer & buf) const
65{
66map.write(buf);
67}
68
69void deserialize(ReadBuffer & buf)
70{
71typename Map::Reader reader(buf);
72while (reader.next())
73{
74const auto & pair = reader.get();
75map[pair.first] = pair.second;
76}
77}
78
79Float64 get() const
80{
81UInt64 total_value = 0;
82for (const auto & pair : map)
83total_value += pair.getMapped();
84
85Float64 shannon_entropy = 0;
86for (const auto & pair : map)
87{
88Float64 frequency = Float64(pair.getMapped()) / total_value;
89shannon_entropy -= frequency * log2(frequency);
90}
91
92return shannon_entropy;
93}
94};
95
96
97template <typename Value>
98class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>
99{
100private:
101size_t num_args;
102
103public:
104explicit AggregateFunctionEntropy(const DataTypes & argument_types_)
105: IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}, createResultType())
106, num_args(argument_types_.size())
107{
108}
109
110String getName() const override { return "entropy"; }
111
112static DataTypePtr createResultType()
113{
114return std::make_shared<DataTypeNumber<Float64>>();
115}
116
117bool allocatesMemoryInArena() const override { return false; }
118
119void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
120{
121if constexpr (!std::is_same_v<UInt128, Value>)
122{
123/// Here we manage only with numerical types
124const auto & column = assert_cast<const ColumnVector <Value> &>(*columns[0]);
125this->data(place).add(column.getData()[row_num]);
126}
127else
128{
129this->data(place).add(UniqVariadicHash<true, false>::apply(num_args, columns, row_num));
130}
131}
132
133void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
134{
135this->data(place).merge(this->data(rhs));
136}
137
138void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
139{
140this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
141}
142
143void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
144{
145this->data(place).deserialize(buf);
146}
147
148void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
149{
150auto & column = assert_cast<ColumnVector<Float64> &>(to);
151column.getData().push_back(this->data(place).get());
152}
153};
154
155
156AggregateFunctionPtr createAggregateFunctionEntropy(
157const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
158{
159assertNoParameters(name, parameters);
160if (argument_types.empty())
161throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
162"Incorrect number of arguments for aggregate function {}", name);
163
164size_t num_args = argument_types.size();
165if (num_args == 1)
166{
167/// Specialized implementation for single argument of numeric type.
168if (auto * res = createWithNumericBasedType<AggregateFunctionEntropy>(*argument_types[0], argument_types))
169return AggregateFunctionPtr(res);
170}
171
172/// Generic implementation for other types or for multiple arguments.
173return std::make_shared<AggregateFunctionEntropy<UInt128>>(argument_types);
174}
175
176}
177
178void registerAggregateFunctionEntropy(AggregateFunctionFactory & factory)
179{
180factory.registerFunction("entropy", createAggregateFunctionEntropy);
181}
182
183}
184