ClickHouse

Форк
0
/
AggregateFunctionEntropy.cpp 
183 строки · 5.2 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/FactoryHelpers.h>
3
#include <AggregateFunctions/Helpers.h>
4

5
#include <Common/HashTable/HashMap.h>
6
#include <Common/NaNUtils.h>
7

8
#include <AggregateFunctions/IAggregateFunction.h>
9
#include <AggregateFunctions/UniqVariadicHash.h>
10
#include <DataTypes/DataTypesNumber.h>
11
#include <Columns/ColumnVector.h>
12
#include <Common/assert_cast.h>
13

14
#include <cmath>
15

16

17
namespace DB
18
{
19
struct Settings;
20

21
namespace ErrorCodes
22
{
23
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
24
}
25

26
namespace
27
{
28

29
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function.
30
  * Entropy is measured in bits (base-2 logarithm is used).
31
  */
32
template <typename Value>
33
struct EntropyData
34
{
35
    using Weight = UInt64;
36

37
    using HashingMap = HashMapWithStackMemory<Value, Weight, HashCRC32<Value>, 4>;
38

39
    /// For the case of pre-hashed values.
40
    using TrivialMap = HashMapWithStackMemory<Value, Weight, UInt128TrivialHash, 4>;
41

42
    using Map = std::conditional_t<std::is_same_v<UInt128, Value>, TrivialMap, HashingMap>;
43

44
    Map map;
45

46
    void add(const Value & x)
47
    {
48
        if (!isNaN(x))
49
            ++map[x];
50
    }
51

52
    void add(const Value & x, const Weight & weight)
53
    {
54
        if (!isNaN(x))
55
            map[x] += weight;
56
    }
57

58
    void merge(const EntropyData & rhs)
59
    {
60
        for (const auto & pair : rhs.map)
61
            map[pair.getKey()] += pair.getMapped();
62
    }
63

64
    void serialize(WriteBuffer & buf) const
65
    {
66
        map.write(buf);
67
    }
68

69
    void deserialize(ReadBuffer & buf)
70
    {
71
        typename Map::Reader reader(buf);
72
        while (reader.next())
73
        {
74
            const auto & pair = reader.get();
75
            map[pair.first] = pair.second;
76
        }
77
    }
78

79
    Float64 get() const
80
    {
81
        UInt64 total_value = 0;
82
        for (const auto & pair : map)
83
            total_value += pair.getMapped();
84

85
        Float64 shannon_entropy = 0;
86
        for (const auto & pair : map)
87
        {
88
            Float64 frequency = Float64(pair.getMapped()) / total_value;
89
            shannon_entropy -= frequency * log2(frequency);
90
        }
91

92
        return shannon_entropy;
93
    }
94
};
95

96

97
template <typename Value>
98
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>
99
{
100
private:
101
    size_t num_args;
102

103
public:
104
    explicit AggregateFunctionEntropy(const DataTypes & argument_types_)
105
        : IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}, createResultType())
106
        , num_args(argument_types_.size())
107
    {
108
    }
109

110
    String getName() const override { return "entropy"; }
111

112
    static DataTypePtr createResultType()
113
    {
114
        return std::make_shared<DataTypeNumber<Float64>>();
115
    }
116

117
    bool allocatesMemoryInArena() const override { return false; }
118

119
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
120
    {
121
        if constexpr (!std::is_same_v<UInt128, Value>)
122
        {
123
            /// Here we manage only with numerical types
124
            const auto & column = assert_cast<const ColumnVector <Value> &>(*columns[0]);
125
            this->data(place).add(column.getData()[row_num]);
126
        }
127
        else
128
        {
129
            this->data(place).add(UniqVariadicHash<true, false>::apply(num_args, columns, row_num));
130
        }
131
    }
132

133
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
134
    {
135
        this->data(place).merge(this->data(rhs));
136
    }
137

138
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
139
    {
140
        this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
141
    }
142

143
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
144
    {
145
        this->data(place).deserialize(buf);
146
    }
147

148
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
149
    {
150
        auto & column = assert_cast<ColumnVector<Float64> &>(to);
151
        column.getData().push_back(this->data(place).get());
152
    }
153
};
154

155

156
/// Factory callback for "entropy": rejects parameters and an empty argument list,
/// then picks the numeric specialization when possible and the hashed (UInt128) one otherwise.
AggregateFunctionPtr createAggregateFunctionEntropy(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertNoParameters(name, parameters);

    const size_t arity = argument_types.size();
    if (arity == 0)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                        "Incorrect number of arguments for aggregate function {}", name);

    if (arity == 1)
    {
        /// A lone argument may get the specialized implementation keyed directly by its numeric value.
        auto * specialized = createWithNumericBasedType<AggregateFunctionEntropy>(*argument_types[0], argument_types);
        if (specialized)
            return AggregateFunctionPtr(specialized);
    }

    /// Everything else (other types, several arguments) goes through the generic implementation.
    return std::make_shared<AggregateFunctionEntropy<UInt128>>(argument_types);
}
175

176
}
177

178
/// Registers the creator of the "entropy" aggregate function in the given factory.
void registerAggregateFunctionEntropy(AggregateFunctionFactory & factory)
{
    factory.registerFunction("entropy", createAggregateFunctionEntropy);
}
182

183
}
184

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.