ClickHouse

Форк
0
/
variantElement.cpp 
238 строк · 11.7 Кб
1
#include <Functions/IFunction.h>
2
#include <Functions/FunctionFactory.h>
3
#include <Functions/FunctionHelpers.h>
4
#include <DataTypes/IDataType.h>
5
#include <DataTypes/DataTypeArray.h>
6
#include <DataTypes/DataTypeVariant.h>
7
#include <DataTypes/DataTypeFactory.h>
8
#include <Columns/ColumnArray.h>
9
#include <Columns/ColumnString.h>
10
#include <Columns/ColumnVariant.h>
11
#include <Columns/ColumnNullable.h>
12
#include <Columns/ColumnLowCardinality.h>
13
#include <Common/assert_cast.h>
14
#include <memory>
15

16

17
namespace DB
18
{
19

20
/// Error codes thrown by variantElement. They are only declared here (extern);
/// the definitions come from another translation unit.
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
25

26
namespace
27
{
28

29
/** Extract element of Variant by variant type name.
30
  * Also the function looks through Arrays: you can get Array of Variant elements from Array of Variants.
31
  */
32
class FunctionVariantElement : public IFunction
33
{
34
public:
35
    static constexpr auto name = "variantElement";
36

37
    static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionVariantElement>(); }
38
    String getName() const override { return name; }
39
    bool isVariadic() const override { return true; }
40
    size_t getNumberOfArguments() const override { return 0; }
41
    bool useDefaultImplementationForConstants() const override { return true; }
42
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
43
    bool useDefaultImplementationForNulls() const override { return false; }
44
    bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }
45
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
46

47
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
48
    {
49
        const size_t number_of_arguments = arguments.size();
50

51
        if (number_of_arguments < 2 || number_of_arguments > 3)
52
            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
53
                            "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
54
                            getName(), number_of_arguments);
55

56
        size_t count_arrays = 0;
57
        const IDataType * input_type = arguments[0].type.get();
58
        while (const DataTypeArray * array = checkAndGetDataType<DataTypeArray>(input_type))
59
        {
60
            input_type = array->getNestedType().get();
61
            ++count_arrays;
62
        }
63

64
        const DataTypeVariant * variant_type = checkAndGetDataType<DataTypeVariant>(input_type);
65
        if (!variant_type)
66
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
67
                    "First argument for function {} must be Variant or Array of Variant. Actual {}",
68
                    getName(),
69
                    arguments[0].type->getName());
70

71
        std::optional<size_t> variant_global_discr = getVariantGlobalDiscriminator(arguments[1].column, *variant_type, number_of_arguments);
72
        if (variant_global_discr.has_value())
73
        {
74
            DataTypePtr return_type = makeNullableOrLowCardinalityNullableSafe(variant_type->getVariant(variant_global_discr.value()));
75

76
            for (; count_arrays; --count_arrays)
77
                return_type = std::make_shared<DataTypeArray>(return_type);
78

79
            return return_type;
80
        }
81
        else
82
            return arguments[2].type;
83
    }
84

85
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
86
    {
87
        const auto & input_arg = arguments[0];
88
        const IDataType * input_type = input_arg.type.get();
89
        const IColumn * input_col = input_arg.column.get();
90

91
        bool input_arg_is_const = false;
92
        if (typeid_cast<const ColumnConst *>(input_col))
93
        {
94
            input_col = assert_cast<const ColumnConst *>(input_col)->getDataColumnPtr().get();
95
            input_arg_is_const = true;
96
        }
97

98
        Columns array_offsets;
99
        while (const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(input_type))
100
        {
101
            const ColumnArray * array_col = assert_cast<const ColumnArray *>(input_col);
102

103
            input_type = array_type->getNestedType().get();
104
            input_col = &array_col->getData();
105
            array_offsets.push_back(array_col->getOffsetsPtr());
106
        }
107

108
        const DataTypeVariant * input_type_as_variant = checkAndGetDataType<DataTypeVariant>(input_type);
109
        const ColumnVariant * input_col_as_variant = checkAndGetColumn<ColumnVariant>(input_col);
110
        if (!input_type_as_variant || !input_col_as_variant)
111
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
112
                            "First argument for function {} must be Variant or array of Variants. Actual {}", getName(), input_arg.type->getName());
113

114
        std::optional<size_t> variant_global_discr = getVariantGlobalDiscriminator(arguments[1].column, *input_type_as_variant, arguments.size());
115

116
        if (!variant_global_discr.has_value())
117
            return arguments[2].column;
118

119
        const auto & variant_type = input_type_as_variant->getVariant(*variant_global_discr);
120
        const auto & variant_column = input_col_as_variant->getVariantPtrByGlobalDiscriminator(*variant_global_discr);
121

122
        /// If Variant has only NULLs or our variant doesn't have any real values,
123
        /// just create column with default values and create null mask with 1.
124
        if (input_col_as_variant->hasOnlyNulls() || variant_column->empty())
125
        {
126
            auto res = variant_type->createColumn();
127

128
            if (variant_type->lowCardinality())
129
                assert_cast<ColumnLowCardinality &>(*res).nestedToNullable();
130

131
            res->insertManyDefaults(input_col_as_variant->size());
132
            if (!variant_type->canBeInsideNullable())
133
                return wrapInArraysAndConstIfNeeded(std::move(res), array_offsets, input_arg_is_const, input_rows_count);
134

135
            auto null_map = ColumnUInt8::create();
136
            auto & null_map_data = null_map->getData();
137
            null_map_data.resize_fill(input_col_as_variant->size(), 1);
138
            return wrapInArraysAndConstIfNeeded(ColumnNullable::create(std::move(res), std::move(null_map)), array_offsets, input_arg_is_const, input_rows_count);
139
        }
140

141
        /// If we extract single non-empty column and have no NULLs, then just return this variant.
142
        if (auto non_empty_local_discr = input_col_as_variant->getLocalDiscriminatorOfOneNoneEmptyVariantNoNulls())
143
        {
144
            /// If we were trying to extract some other variant,
145
            /// it would be empty and we would already processed this case above.
146
            chassert(input_col_as_variant->globalDiscriminatorByLocal(*non_empty_local_discr) == variant_global_discr);
147
            return wrapInArraysAndConstIfNeeded(makeNullableOrLowCardinalityNullableSafe(variant_column), array_offsets, input_arg_is_const, input_rows_count);
148
        }
149

150
        /// In general case we should calculate null-mask for variant
151
        /// according to the discriminators column and expand
152
        /// variant column by this mask to get a full column (with default values on NULLs)
153
        const auto & local_discriminators = input_col_as_variant->getLocalDiscriminators();
154
        auto null_map = ColumnUInt8::create();
155
        auto & null_map_data = null_map->getData();
156
        null_map_data.reserve(local_discriminators.size());
157
        auto variant_local_discr = input_col_as_variant->localDiscriminatorByGlobal(*variant_global_discr);
158
        for (auto local_discr : local_discriminators)
159
            null_map_data.push_back(local_discr != variant_local_discr);
160

161
        auto expanded_variant_column = IColumn::mutate(variant_column);
162
        if (variant_type->lowCardinality())
163
            expanded_variant_column = assert_cast<ColumnLowCardinality &>(*expanded_variant_column).cloneNullable();
164
        expanded_variant_column->expand(null_map_data, /*inverted = */ true);
165
        if (variant_type->canBeInsideNullable())
166
            return wrapInArraysAndConstIfNeeded(ColumnNullable::create(std::move(expanded_variant_column), std::move(null_map)), array_offsets, input_arg_is_const, input_rows_count);
167
        return wrapInArraysAndConstIfNeeded(std::move(expanded_variant_column), array_offsets, input_arg_is_const, input_rows_count);
168
    }
169
private:
170
    std::optional<size_t> getVariantGlobalDiscriminator(const ColumnPtr & index_column, const DataTypeVariant & variant_type, size_t argument_size) const
171
    {
172
        const auto * name_col = checkAndGetColumnConst<ColumnString>(index_column.get());
173
        if (!name_col)
174
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
175
                            "Second argument to {} with Variant argument must be a constant String",
176
                            getName());
177

178
        String variant_element_name = name_col->getValue<String>();
179
        auto variant_element_type = DataTypeFactory::instance().tryGet(variant_element_name);
180
        if (variant_element_type)
181
        {
182
            const auto & variants = variant_type.getVariants();
183
            for (size_t i = 0; i != variants.size(); ++i)
184
            {
185
                if (variants[i]->getName() == variant_element_type->getName())
186
                    return i;
187
            }
188
        }
189

190
        if (argument_size == 2)
191
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} doesn't contain variant with type {}", variant_type.getName(), variant_element_name);
192
        return std::nullopt;
193
    }
194

195
    ColumnPtr wrapInArraysAndConstIfNeeded(ColumnPtr res, const Columns & array_offsets, bool input_arg_is_const, size_t input_rows_count) const
196
    {
197
        for (auto it = array_offsets.rbegin(); it != array_offsets.rend(); ++it)
198
            res = ColumnArray::create(res, *it);
199

200
        if (input_arg_is_const)
201
            res = ColumnConst::create(res, input_rows_count);
202

203
        return res;
204
    }
205
};
206

207
}
208

209
/// Registers variantElement and its user-facing documentation.
/// The third argument is optional, so the syntax string shows it as `[, default_value]`
/// directly after `type_name` (no extra comma before the bracket).
REGISTER_FUNCTION(VariantElement)
{
    factory.registerFunction<FunctionVariantElement>(FunctionDocumentation{
        .description = R"(
Extracts a column with specified type from a `Variant` column.
)",
        .syntax{"variantElement(variant, type_name[, default_value])"},
        .arguments{{
            {"variant", "Variant column"},
            {"type_name", "The name of the variant type to extract"},
            {"default_value", "The default value that will be used if variant doesn't have variant with specified type. Can be any type. Optional"}}},
        .examples{{{
            "Example",
            R"(
CREATE TABLE test (v Variant(UInt64, String, Array(UInt64))) ENGINE = Memory;
INSERT INTO test VALUES (NULL), (42), ('Hello, World!'), ([1, 2, 3]);
SELECT v, variantElement(v, 'String'), variantElement(v, 'UInt64'), variantElement(v, 'Array(UInt64)') FROM test;)",
            R"(
┌─v─────────────┬─variantElement(v, 'String')─┬─variantElement(v, 'UInt64')─┬─variantElement(v, 'Array(UInt64)')─┐
│ ᴺᵁᴸᴸ          │ ᴺᵁᴸᴸ                        │                        ᴺᵁᴸᴸ │ []                                 │
│ 42            │ ᴺᵁᴸᴸ                        │                          42 │ []                                 │
│ Hello, World! │ Hello, World!               │                        ᴺᵁᴸᴸ │ []                                 │
│ [1,2,3]       │ ᴺᵁᴸᴸ                        │                        ᴺᵁᴸᴸ │ [1,2,3]                            │
└───────────────┴─────────────────────────────┴─────────────────────────────┴────────────────────────────────────┘
)"}}},
        .categories{"Variant"},
    });
}
237

238
}
239

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.