ClickHouse

Форк
0
/
AggregateFunctionHistogram.cpp 
426 строк · 13.1 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/FactoryHelpers.h>
3
#include <AggregateFunctions/Helpers.h>
4
#include <Common/FieldVisitorConvertToNumber.h>
5

6
#include <Common/NaNUtils.h>
7

8
#include <Columns/ColumnVector.h>
9
#include <Columns/ColumnTuple.h>
10
#include <Columns/ColumnArray.h>
11
#include <Common/assert_cast.h>
12

13
#include <DataTypes/DataTypesNumber.h>
14
#include <DataTypes/DataTypeArray.h>
15
#include <DataTypes/DataTypeTuple.h>
16

17
#include <IO/WriteBuffer.h>
18
#include <IO/ReadBuffer.h>
19
#include <IO/WriteHelpers.h>
20
#include <IO/ReadHelpers.h>
21
#include <IO/VarInt.h>
22

23
#include <AggregateFunctions/IAggregateFunction.h>
24

25
#include <queue>
26
#include <cmath>
27
#include <cstddef>
28

29

30
namespace DB
31
{
32
struct Settings;
33

34
namespace ErrorCodes
35
{
36
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
37
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
38
    extern const int BAD_ARGUMENTS;
39
    extern const int UNSUPPORTED_PARAMETER;
40
    extern const int PARAMETER_OUT_OF_BOUND;
41
    extern const int TOO_LARGE_ARRAY_SIZE;
42
    extern const int INCORRECT_DATA;
43
}
44

45

46
namespace
47
{
48

49
/** distance compression algorithm implementation
50
  * http://jmlr.org/papers/volume11/ben-haim10a/ben-haim10a.pdf
51
  */
52
class AggregateFunctionHistogramData
53
{
54
public:
55
    using Mean = Float64;
56
    using Weight = Float64;
57

58
    constexpr static size_t bins_count_limit = 250;
59

60
private:
61
    struct WeightedValue
62
    {
63
        Mean mean;
64
        Weight weight;
65

66
        WeightedValue operator+(const WeightedValue & other) const
67
        {
68
            return {mean + other.weight * (other.mean - mean) / (other.weight + weight), other.weight + weight};
69
        }
70
    };
71

72
    // quantity of stored weighted-values
73
    UInt32 size;
74

75
    // calculated lower and upper bounds of seen points
76
    Mean lower_bound;
77
    Mean upper_bound;
78

79
    // Weighted values representation of histogram.
80
    WeightedValue points[0];
81

82
    void sort()
83
    {
84
        ::sort(points, points + size,
85
            [](const WeightedValue & first, const WeightedValue & second)
86
            {
87
                return first.mean < second.mean;
88
            });
89
    }
90

91
    template <typename T>
92
    struct PriorityQueueStorage
93
    {
94
        size_t size = 0;
95
        T * data_ptr;
96

97
        explicit PriorityQueueStorage(T * value)
98
            : data_ptr(value)
99
        {
100
        }
101

102
        void push_back(T val) /// NOLINT
103
        {
104
            data_ptr[size] = std::move(val);
105
            ++size;
106
        }
107

108
        void pop_back() { --size; } /// NOLINT
109
        T * begin() { return data_ptr; }
110
        T * end() const { return data_ptr + size; }
111
        bool empty() const { return size == 0; }
112
        T & front() { return *data_ptr; }
113
        const T & front() const { return *data_ptr; }
114

115
        using value_type = T;
116
        using reference = T&;
117
        using const_reference = const T&;
118
        using size_type = size_t;
119
    };
120

121
    /**
122
     * Repeatedly fuse most close values until max_bins bins left
123
     */
124
    void compress(UInt32 max_bins)
125
    {
126
        sort();
127
        auto new_size = size;
128
        if (size <= max_bins)
129
            return;
130

131
        // Maintain doubly-linked list of "active" points
132
        // and store neighbour pairs in priority queue by distance
133
        UInt32 previous[size + 1];
134
        UInt32 next[size + 1];
135
        bool active[size + 1];
136
        std::fill(active, active + size, true);
137
        active[size] = false;
138

139
        auto delete_node = [&](UInt32 i)
140
        {
141
            previous[next[i]] = previous[i];
142
            next[previous[i]] = next[i];
143
            active[i] = false;
144
        };
145

146
        for (size_t i = 0; i <= size; ++i)
147
        {
148
            previous[i] = static_cast<UInt32>(i - 1);
149
            next[i] = static_cast<UInt32>(i + 1);
150
        }
151

152
        next[size] = 0;
153
        previous[0] = size;
154

155
        using QueueItem = std::pair<Mean, UInt32>;
156

157
        QueueItem storage[2 * size - max_bins];
158

159
        std::priority_queue<
160
            QueueItem,
161
            PriorityQueueStorage<QueueItem>,
162
            std::greater<>>
163
                queue{std::greater<>(),
164
                        PriorityQueueStorage<QueueItem>(storage)};
165

166
        auto quality = [&](UInt32 i) { return points[next[i]].mean - points[i].mean; };
167

168
        for (size_t i = 0; i + 1 < size; ++i)
169
            queue.push({quality(static_cast<UInt32>(i)), i});
170

171
        while (new_size > max_bins && !queue.empty())
172
        {
173
            auto min_item = queue.top();
174
            queue.pop();
175
            auto left = min_item.second;
176
            auto right = next[left];
177

178
            if (!active[left] || !active[right] || quality(left) > min_item.first)
179
                continue;
180

181
            points[left] = points[left] + points[right];
182

183
            delete_node(right);
184
            if (active[next[left]])
185
                queue.push({quality(left), left});
186
            if (active[previous[left]])
187
                queue.push({quality(previous[left]), previous[left]});
188

189
            --new_size;
190
        }
191

192
        size_t left = 0;
193
        for (size_t right = 0; right < size; ++right)
194
        {
195
            if (active[right])
196
            {
197
                points[left] = points[right];
198
                ++left;
199
            }
200
        }
201
        size = new_size;
202
    }
203

204
    /***
205
     * Delete too close points from histogram.
206
     * Assumes that points are sorted.
207
     */
208
    void unique()
209
    {
210
        if (size == 0)
211
            return;
212

213
        size_t left = 0;
214

215
        for (auto right = left + 1; right < size; ++right)
216
        {
217
            // Fuse points if their text representations differ only in last digit
218
            auto min_diff = 10 * (points[left].mean + points[right].mean) * std::numeric_limits<Mean>::epsilon();
219
            if (points[left].mean + std::fabs(min_diff) >= points[right].mean)
220
            {
221
                points[left] = points[left] + points[right];
222
            }
223
            else
224
            {
225
                ++left;
226
                points[left] = points[right];
227
            }
228
        }
229
        size = static_cast<UInt32>(left + 1);
230
    }
231

232
public:
233
    AggregateFunctionHistogramData()
234
        : size(0)
235
        , lower_bound(std::numeric_limits<Mean>::max())
236
        , upper_bound(std::numeric_limits<Mean>::lowest())
237
    {
238
        static_assert(offsetof(AggregateFunctionHistogramData, points) == sizeof(AggregateFunctionHistogramData), "points should be last member");
239
    }
240

241
    static size_t structSize(size_t max_bins)
242
    {
243
        return sizeof(AggregateFunctionHistogramData) + max_bins * 2 * sizeof(WeightedValue);
244
    }
245

246
    void insertResultInto(ColumnVector<Mean> & to_lower, ColumnVector<Mean> & to_upper, ColumnVector<Weight> & to_weights, UInt32 max_bins)
247
    {
248
        compress(max_bins);
249
        unique();
250

251
        for (size_t i = 0; i < size; ++i)
252
        {
253
            to_lower.insertValue((i == 0) ? lower_bound : (points[i].mean + points[i - 1].mean) / 2);
254
            to_upper.insertValue((i + 1 == size) ? upper_bound : (points[i].mean + points[i + 1].mean) / 2);
255

256
            // linear density approximation
257
            Weight lower_weight = (i == 0) ? points[i].weight : ((points[i - 1].weight) + points[i].weight * 3) / 4;
258
            Weight upper_weight = (i + 1 == size) ? points[i].weight : (points[i + 1].weight + points[i].weight * 3) / 4;
259
            to_weights.insertValue((lower_weight + upper_weight) / 2);
260
        }
261
    }
262

263
    void add(Mean value, Weight weight, UInt32 max_bins)
264
    {
265
        // nans break sort and compression
266
        // infs don't fit in bins partition method
267
        if (!isFinite(value))
268
            throw Exception(ErrorCodes::INCORRECT_DATA, "Invalid value (inf or nan) for aggregation by 'histogram' function");
269

270
        points[size] = {value, weight};
271
        ++size;
272
        lower_bound = std::min(lower_bound, value);
273
        upper_bound = std::max(upper_bound, value);
274

275
        if (size >= max_bins * 2)
276
            compress(max_bins);
277
    }
278

279
    void merge(const AggregateFunctionHistogramData & other, UInt32 max_bins)
280
    {
281
        lower_bound = std::min(lower_bound, other.lower_bound);
282
        upper_bound = std::max(upper_bound, other.upper_bound);
283
        for (size_t i = 0; i < other.size; ++i)
284
            add(other.points[i].mean, other.points[i].weight, max_bins);
285
    }
286

287
    void write(WriteBuffer & buf) const
288
    {
289
        writeBinary(lower_bound, buf);
290
        writeBinary(upper_bound, buf);
291

292
        writeVarUInt(size, buf);
293
        buf.write(reinterpret_cast<const char *>(points), size * sizeof(WeightedValue));
294
    }
295

296
    void read(ReadBuffer & buf, UInt32 max_bins)
297
    {
298
        readBinary(lower_bound, buf);
299
        readBinary(upper_bound, buf);
300

301
        readVarUInt(size, buf);
302
        if (size > max_bins * 2)
303
            throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too many bins");
304
        static constexpr size_t max_size = 1_GiB;
305
        if (size > max_size)
306
            throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
307
                            "Too large array size in histogram (maximum: {})", max_size);
308

309
        buf.readStrict(reinterpret_cast<char *>(points), size * sizeof(WeightedValue));
310
    }
311
};
312

313
template <typename T>
314
class AggregateFunctionHistogram final: public IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>
315
{
316
private:
317
    using Data = AggregateFunctionHistogramData;
318

319
    const UInt32 max_bins;
320

321
public:
322
    AggregateFunctionHistogram(UInt32 max_bins_, const DataTypes & arguments, const Array & params)
323
        : IAggregateFunctionDataHelper<AggregateFunctionHistogramData, AggregateFunctionHistogram<T>>(arguments, params, createResultType())
324
        , max_bins(max_bins_)
325
    {
326
    }
327

328
    size_t sizeOfData() const override
329
    {
330
        return Data::structSize(max_bins);
331
    }
332
    static DataTypePtr createResultType()
333
    {
334
        DataTypes types;
335
        auto mean = std::make_shared<DataTypeNumber<Data::Mean>>();
336
        auto weight = std::make_shared<DataTypeNumber<Data::Weight>>();
337

338
        // lower bound
339
        types.emplace_back(mean);
340
        // upper bound
341
        types.emplace_back(mean);
342
        // weight
343
        types.emplace_back(weight);
344

345
        auto tuple = std::make_shared<DataTypeTuple>(types);
346
        return std::make_shared<DataTypeArray>(tuple);
347
    }
348

349
    bool allocatesMemoryInArena() const override { return false; }
350

351
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
352
    {
353
        auto val = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
354
        this->data(place).add(static_cast<Data::Mean>(val), 1, max_bins);
355
    }
356

357
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
358
    {
359
        this->data(place).merge(this->data(rhs), max_bins);
360
    }
361

362
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
363
    {
364
        this->data(place).write(buf);
365
    }
366

367
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
368
    {
369
        this->data(place).read(buf, max_bins);
370
    }
371

372
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
373
    {
374
        auto & data = this->data(place);
375

376
        auto & to_array = assert_cast<ColumnArray &>(to);
377
        ColumnArray::Offsets & offsets_to = to_array.getOffsets();
378
        auto & to_tuple = assert_cast<ColumnTuple &>(to_array.getData());
379

380
        auto & to_lower = assert_cast<ColumnVector<Data::Mean> &>(to_tuple.getColumn(0));
381
        auto & to_upper = assert_cast<ColumnVector<Data::Mean> &>(to_tuple.getColumn(1));
382
        auto & to_weights = assert_cast<ColumnVector<Data::Weight> &>(to_tuple.getColumn(2));
383
        data.insertResultInto(to_lower, to_upper, to_weights, max_bins);
384

385
        offsets_to.push_back(to_tuple.size());
386
    }
387

388
    String getName() const override { return "histogram"; }
389
};
390

391

392
/** Factory callback for the "histogram" aggregate function.
  *
  * Expects exactly one parameter of type UInt64 (the bin count, in the range
  * 1 .. bins_count_limit) and exactly one argument of a numeric type.
  *
  * Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH, UNSUPPORTED_PARAMETER,
  * PARAMETER_OUT_OF_BOUND, BAD_ARGUMENTS or ILLEGAL_TYPE_OF_ARGUMENT
  * when the respective check fails.
  */
AggregateFunctionPtr createAggregateFunctionHistogram(const std::string & name, const DataTypes & arguments, const Array & params, const Settings *)
{
    if (params.size() != 1)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires single parameter: bins count", name);

    if (params[0].getType() != Field::Types::UInt64)
        throw Exception(ErrorCodes::UNSUPPORTED_PARAMETER, "Invalid type for bins count");

    /// Validate at full 64-bit width. Converting to UInt32 before the range check
    /// may narrow the value, letting e.g. 2^32 + 5 be accepted as 5.
    const UInt64 requested_bins = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);

    const auto limit = AggregateFunctionHistogramData::bins_count_limit;
    if (requested_bins > limit)
        throw Exception(ErrorCodes::PARAMETER_OUT_OF_BOUND, "Unsupported bins count. Should not be greater than {}", limit);

    if (requested_bins == 0)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Bin count should be positive");

    /// Safe to narrow: the value was just checked against the limit.
    const UInt32 bins_count = static_cast<UInt32>(requested_bins);

    assertUnary(name, arguments);
    AggregateFunctionPtr res(createWithNumericType<AggregateFunctionHistogram>(*arguments[0], bins_count, arguments, params));

    /// An empty result means no numeric type matched the argument.
    if (!res)
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                        "Illegal type {} of argument for aggregate function {}", arguments[0]->getName(), name);

    return res;
}
418

419
}
420

421
/// Makes the aggregate function available in the factory under the name "histogram".
void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory)
{
    static constexpr auto function_name = "histogram";
    factory.registerFunction(function_name, createAggregateFunctionHistogram);
}
425

426
}
427

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.