ClickHouse

Форк
0
/
AggregateFunctionGroupArrayInsertAt.cpp 
234 строки · 8.1 Кб
1
#include <AggregateFunctions/AggregateFunctionFactory.h>
2
#include <AggregateFunctions/FactoryHelpers.h>
3

4
#include <IO/WriteHelpers.h>
5
#include <IO/ReadHelpers.h>
6

7
#include <DataTypes/DataTypeArray.h>
8
#include <DataTypes/DataTypesNumber.h>
9

10
#include <Columns/ColumnArray.h>
11

12
#include <Common/FieldVisitorToString.h>
13
#include <Common/FieldVisitorConvertToNumber.h>
14
#include <Common/assert_cast.h>
15
#include <Interpreters/convertFieldToType.h>
16

17
#include <AggregateFunctions/IAggregateFunction.h>
18

19
#define AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE 0xFFFFFF
20

21

22
namespace DB
23
{
24

25
struct Settings;
26

27
/// Error codes used by the exceptions thrown in this file.
/// They are only declared here (extern); the definitions live outside this file.
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int TOO_LARGE_ARRAY_SIZE;
    extern const int CANNOT_CONVERT_TYPE;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
34

35
namespace
36
{
37

38
/** Aggregate function, that takes two arguments: value and position,
39
  *  and as a result, builds an array in which the values are located at the corresponding positions.
40
  *
41
  * If more than one value was inserted at a single position, any one of them (the first one in the single-threaded case) is stored.
42
  * If no value was inserted at some position, the default value is substituted.
43
  *
44
  * The aggregate function also accepts optional parameters:
45
  * - default value to substitute;
46
  * - length to resize result arrays (if you want to have results of same length for all aggregation keys);
47
  *
48
  * If you want to pass the length, the default value must also be given.
49
  */
50

51

52
/// Generic case (inefficient).
53
struct AggregateFunctionGroupArrayInsertAtDataGeneric
54
{
55
    Array value;    /// TODO Add MemoryTracker
56
};
57

58

59
class AggregateFunctionGroupArrayInsertAtGeneric final
60
    : public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>
61
{
62
private:
63
    DataTypePtr type;
64
    SerializationPtr serialization;
65
    Field default_value;
66
    UInt64 length_to_resize = 0;    /// zero means - do not do resizing.
67

68
public:
69
    AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
70
        : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params, std::make_shared<DataTypeArray>(arguments[0]))
71
        , type(argument_types[0])
72
        , serialization(type->getDefaultSerialization())
73
    {
74
        if (!params.empty())
75
        {
76
            if (params.size() > 2)
77
                throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at most two parameters.", getName());
78

79
            default_value = params[0];
80

81
            if (params.size() == 2)
82
            {
83
                length_to_resize = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
84
                if (length_to_resize > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
85
                    throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
86
                                    "Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
87
            }
88
        }
89

90
        if (!isUInt(arguments[1]))
91
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of aggregate function {} must be unsigned integer.", getName());
92

93
        if (default_value.isNull())
94
            default_value = type->getDefault();
95
        else
96
        {
97
            Field converted = convertFieldToType(default_value, *type);
98
            if (converted.isNull())
99
                throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Cannot convert parameter of aggregate function {} ({}) "
100
                                "to type {} to be used as default value in array",
101
                                getName(), applyVisitor(FieldVisitorToString(), default_value), type->getName());
102

103
            default_value = converted;
104
        }
105
    }
106

107
    String getName() const override { return "groupArrayInsertAt"; }
108

109
    bool allocatesMemoryInArena() const override { return false; }
110

111
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
112
    {
113
        /// TODO Do positions need to be 1-based for this function?
114
        size_t position = columns[1]->getUInt(row_num);
115

116
        /// If position is larger than size to which array will be cut - simply ignore value.
117
        if (length_to_resize && position >= length_to_resize)
118
            return;
119

120
        if (position >= AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
121
            throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: "
122
                "position argument ({}) is greater or equals to limit ({})",
123
                position, AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
124

125
        Array & arr = data(place).value;
126

127
        if (arr.size() <= position)
128
            arr.resize(position + 1);
129
        else if (!arr[position].isNull())
130
            return; /// Element was already inserted to the specified position.
131

132
        columns[0]->get(row_num, arr[position]);
133
    }
134

135
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
136
    {
137
        Array & arr_lhs = data(place).value;
138
        const Array & arr_rhs = data(rhs).value;
139

140
        if (arr_lhs.size() < arr_rhs.size())
141
            arr_lhs.resize(arr_rhs.size());
142

143
        for (size_t i = 0, size = arr_rhs.size(); i < size; ++i)
144
            if (arr_lhs[i].isNull() && !arr_rhs[i].isNull())
145
                arr_lhs[i] = arr_rhs[i];
146
    }
147

148
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
149
    {
150
        const Array & arr = data(place).value;
151
        size_t size = arr.size();
152
        writeVarUInt(size, buf);
153

154
        for (const Field & elem : arr)
155
        {
156
            if (elem.isNull())
157
            {
158
                writeBinary(UInt8(1), buf);
159
            }
160
            else
161
            {
162
                writeBinary(UInt8(0), buf);
163
                serialization->serializeBinary(elem, buf, {});
164
            }
165
        }
166
    }
167

168
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
169
    {
170
        size_t size = 0;
171
        readVarUInt(size, buf);
172

173
        if (size > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
174
            throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
175
                            "Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
176

177
        Array & arr = data(place).value;
178

179
        arr.resize(size);
180
        for (size_t i = 0; i < size; ++i)
181
        {
182
            UInt8 is_null = 0;
183
            readBinary(is_null, buf);
184
            if (!is_null)
185
                serialization->deserializeBinary(arr[i], buf, {});
186
        }
187
    }
188

189
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
190
    {
191
        ColumnArray & to_array = assert_cast<ColumnArray &>(to);
192
        IColumn & to_data = to_array.getData();
193
        ColumnArray::Offsets & to_offsets = to_array.getOffsets();
194

195
        const Array & arr = data(place).value;
196

197
        for (const Field & elem : arr)
198
        {
199
            if (!elem.isNull())
200
                to_data.insert(elem);
201
            else
202
                to_data.insert(default_value);
203
        }
204

205
        size_t result_array_size = length_to_resize ? length_to_resize : arr.size();
206

207
        /// Pad array if need.
208
        for (size_t i = arr.size(); i < result_array_size; ++i)
209
            to_data.insert(default_value);
210

211
        to_offsets.push_back(to_offsets.back() + result_array_size);
212
    }
213
};
214

215

216
/// Factory callback: creates the groupArrayInsertAt aggregate function for the given
/// argument types (value, position) and optional parameters (default value, length).
/// Exceptions about invalid arguments/parameters come from assertBinary and from the function's constructor.
AggregateFunctionPtr createAggregateFunctionGroupArrayInsertAt(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    /// assertBinary (FactoryHelpers.h) already throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly
    /// two arguments are passed, so a separate check of argument_types.size() would be unreachable.
    assertBinary(name, argument_types);

    return std::make_shared<AggregateFunctionGroupArrayInsertAtGeneric>(argument_types, parameters);
}
226

227
}
228

229
/// Registers the creator of the aggregate function under the name "groupArrayInsertAt" in the given factory.
void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory & factory)
{
    factory.registerFunction("groupArrayInsertAt", createAggregateFunctionGroupArrayInsertAt);
}
233

234
}
235

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.