ClickHouse
234 строки · 8.1 Кб
1#include <AggregateFunctions/AggregateFunctionFactory.h>
2#include <AggregateFunctions/FactoryHelpers.h>
3
4#include <IO/WriteHelpers.h>
5#include <IO/ReadHelpers.h>
6
7#include <DataTypes/DataTypeArray.h>
8#include <DataTypes/DataTypesNumber.h>
9
10#include <Columns/ColumnArray.h>
11
12#include <Common/FieldVisitorToString.h>
13#include <Common/FieldVisitorConvertToNumber.h>
14#include <Common/assert_cast.h>
15#include <Interpreters/convertFieldToType.h>
16
17#include <AggregateFunctions/IAggregateFunction.h>
18
19#define AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE 0xFFFFFF
20
21
22namespace DB
23{
24
25struct Settings;
26
27namespace ErrorCodes
28{
29extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
30extern const int TOO_LARGE_ARRAY_SIZE;
31extern const int CANNOT_CONVERT_TYPE;
32extern const int ILLEGAL_TYPE_OF_ARGUMENT;
33}
34
35namespace
36{
37
/** Aggregate function that takes two arguments: value and position,
  * and as a result, builds an array with the values located at the corresponding positions.
  *
  * If more than one value was inserted at a single position, any one of them (the first one in the single-threaded case) is stored.
  * If no value was inserted at some position, the default value is substituted.
  *
  * The aggregate function also accepts optional parameters:
  * - default value to substitute;
  * - length to resize result arrays to (if you want to have results of the same length for all aggregation keys).
  *
  * If you want to pass the length, the default value must also be given.
  */
50
51
52/// Generic case (inefficient).
53struct AggregateFunctionGroupArrayInsertAtDataGeneric
54{
55Array value; /// TODO Add MemoryTracker
56};
57
58
59class AggregateFunctionGroupArrayInsertAtGeneric final
60: public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>
61{
62private:
63DataTypePtr type;
64SerializationPtr serialization;
65Field default_value;
66UInt64 length_to_resize = 0; /// zero means - do not do resizing.
67
68public:
69AggregateFunctionGroupArrayInsertAtGeneric(const DataTypes & arguments, const Array & params)
70: IAggregateFunctionDataHelper<AggregateFunctionGroupArrayInsertAtDataGeneric, AggregateFunctionGroupArrayInsertAtGeneric>(arguments, params, std::make_shared<DataTypeArray>(arguments[0]))
71, type(argument_types[0])
72, serialization(type->getDefaultSerialization())
73{
74if (!params.empty())
75{
76if (params.size() > 2)
77throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires at most two parameters.", getName());
78
79default_value = params[0];
80
81if (params.size() == 2)
82{
83length_to_resize = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
84if (length_to_resize > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
85throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
86"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
87}
88}
89
90if (!isUInt(arguments[1]))
91throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument of aggregate function {} must be unsigned integer.", getName());
92
93if (default_value.isNull())
94default_value = type->getDefault();
95else
96{
97Field converted = convertFieldToType(default_value, *type);
98if (converted.isNull())
99throw Exception(ErrorCodes::CANNOT_CONVERT_TYPE, "Cannot convert parameter of aggregate function {} ({}) "
100"to type {} to be used as default value in array",
101getName(), applyVisitor(FieldVisitorToString(), default_value), type->getName());
102
103default_value = converted;
104}
105}
106
107String getName() const override { return "groupArrayInsertAt"; }
108
109bool allocatesMemoryInArena() const override { return false; }
110
111void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
112{
113/// TODO Do positions need to be 1-based for this function?
114size_t position = columns[1]->getUInt(row_num);
115
116/// If position is larger than size to which array will be cut - simply ignore value.
117if (length_to_resize && position >= length_to_resize)
118return;
119
120if (position >= AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
121throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: "
122"position argument ({}) is greater or equals to limit ({})",
123position, AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
124
125Array & arr = data(place).value;
126
127if (arr.size() <= position)
128arr.resize(position + 1);
129else if (!arr[position].isNull())
130return; /// Element was already inserted to the specified position.
131
132columns[0]->get(row_num, arr[position]);
133}
134
135void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
136{
137Array & arr_lhs = data(place).value;
138const Array & arr_rhs = data(rhs).value;
139
140if (arr_lhs.size() < arr_rhs.size())
141arr_lhs.resize(arr_rhs.size());
142
143for (size_t i = 0, size = arr_rhs.size(); i < size; ++i)
144if (arr_lhs[i].isNull() && !arr_rhs[i].isNull())
145arr_lhs[i] = arr_rhs[i];
146}
147
148void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
149{
150const Array & arr = data(place).value;
151size_t size = arr.size();
152writeVarUInt(size, buf);
153
154for (const Field & elem : arr)
155{
156if (elem.isNull())
157{
158writeBinary(UInt8(1), buf);
159}
160else
161{
162writeBinary(UInt8(0), buf);
163serialization->serializeBinary(elem, buf, {});
164}
165}
166}
167
168void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
169{
170size_t size = 0;
171readVarUInt(size, buf);
172
173if (size > AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE)
174throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
175"Too large array size (maximum: {})", AGGREGATE_FUNCTION_GROUP_ARRAY_INSERT_AT_MAX_SIZE);
176
177Array & arr = data(place).value;
178
179arr.resize(size);
180for (size_t i = 0; i < size; ++i)
181{
182UInt8 is_null = 0;
183readBinary(is_null, buf);
184if (!is_null)
185serialization->deserializeBinary(arr[i], buf, {});
186}
187}
188
189void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
190{
191ColumnArray & to_array = assert_cast<ColumnArray &>(to);
192IColumn & to_data = to_array.getData();
193ColumnArray::Offsets & to_offsets = to_array.getOffsets();
194
195const Array & arr = data(place).value;
196
197for (const Field & elem : arr)
198{
199if (!elem.isNull())
200to_data.insert(elem);
201else
202to_data.insert(default_value);
203}
204
205size_t result_array_size = length_to_resize ? length_to_resize : arr.size();
206
207/// Pad array if need.
208for (size_t i = arr.size(); i < result_array_size; ++i)
209to_data.insert(default_value);
210
211to_offsets.push_back(to_offsets.back() + result_array_size);
212}
213};
214
215
216AggregateFunctionPtr createAggregateFunctionGroupArrayInsertAt(
217const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
218{
219assertBinary(name, argument_types);
220
221if (argument_types.size() != 2)
222throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function groupArrayInsertAt requires two arguments.");
223
224return std::make_shared<AggregateFunctionGroupArrayInsertAtGeneric>(argument_types, parameters);
225}
226
227}
228
229void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory & factory)
230{
231factory.registerFunction("groupArrayInsertAt", createAggregateFunctionGroupArrayInsertAt);
232}
233
234}
235