// ClickHouse
// 147 lines · 5.7 KB
1#include <Functions/FunctionFactory.h>2#include <Functions/CastOverloadResolver.h>3#include <Functions/FunctionHelpers.h>4#include <DataTypes/DataTypeFactory.h>5#include <DataTypes/DataTypeNullable.h>6#include <Columns/ColumnString.h>7#include <Interpreters/parseColumnsListForTableFunction.h>8#include <Interpreters/Context.h>9
10
11namespace DB12{
13
/// Error code (defined elsewhere) thrown below when the target-type argument is not a constant String.
namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; }
18
19FunctionBasePtr createFunctionBaseCast(20ContextPtr context,21const char * name,22const ColumnsWithTypeAndName & arguments,23const DataTypePtr & return_type,24std::optional<CastDiagnostic> diagnostic,25CastType cast_type);26
27
28/** CastInternal does not preserve nullability of the data type,
29* i.e. CastInternal(toNullable(toInt8(1)) as Int32) will be Int32(1).
30*
31* Cast preserves nullability according to setting `cast_keep_nullable`,
32* i.e. Cast(toNullable(toInt8(1)) as Int32) will be Nullable(Int32(1)) if `cast_keep_nullable` == 1.
33*/
34class CastOverloadResolverImpl : public IFunctionOverloadResolver35{
36public:37const char * getNameImpl() const38{39if (cast_type == CastType::accurate)40return "accurateCast";41if (cast_type == CastType::accurateOrNull)42return "accurateCastOrNull";43if (internal)44return "_CAST";45else46return "CAST";47}48
49String getName() const override50{51return getNameImpl();52}53
54size_t getNumberOfArguments() const override { return 2; }55
56ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }57
58explicit CastOverloadResolverImpl(ContextPtr context_, CastType cast_type_, bool internal_, std::optional<CastDiagnostic> diagnostic_, bool keep_nullable_, const DataTypeValidationSettings & data_type_validation_settings_)59: context(context_)60, cast_type(cast_type_)61, internal(internal_)62, diagnostic(std::move(diagnostic_))63, keep_nullable(keep_nullable_)64, data_type_validation_settings(data_type_validation_settings_)65{66}67
68static FunctionOverloadResolverPtr create(ContextPtr context, CastType cast_type, bool internal, std::optional<CastDiagnostic> diagnostic)69{70if (internal)71{72return std::make_unique<CastOverloadResolverImpl>(context, cast_type, internal, diagnostic, false /*keep_nullable*/, DataTypeValidationSettings{});73}74else75{76const auto & settings_ref = context->getSettingsRef();77return std::make_unique<CastOverloadResolverImpl>(context, cast_type, internal, diagnostic, settings_ref.cast_keep_nullable, DataTypeValidationSettings(settings_ref));78}79}80
81protected:82FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override83{84return createFunctionBaseCast(context, getNameImpl(), arguments, return_type, diagnostic, cast_type);85}86
87DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override88{89const auto & column = arguments.back().column;90if (!column)91throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument to {} must be a constant string describing type. "92"Instead there is non-constant column of type {}", getName(), arguments.back().type->getName());93
94const auto * type_col = checkAndGetColumnConst<ColumnString>(column.get());95if (!type_col)96throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument to {} must be a constant string describing type. "97"Instead there is a column with the following structure: {}", getName(), column->dumpStructure());98
99DataTypePtr type = DataTypeFactory::instance().get(type_col->getValue<String>());100validateDataType(type, data_type_validation_settings);101
102if (cast_type == CastType::accurateOrNull)103{104/// Variant handles NULLs by itself during conversions.105if (!isVariant(type))106return makeNullable(type);107}108
109if (internal)110return type;111
112if (keep_nullable && arguments.front().type->isNullable() && type->canBeInsideNullable())113return makeNullable(type);114
115return type;116}117
118bool useDefaultImplementationForNulls() const override { return false; }119bool useDefaultImplementationForNothing() const override { return false; }120bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }121
122private:123ContextPtr context;124CastType cast_type;125bool internal;126std::optional<CastDiagnostic> diagnostic;127bool keep_nullable;128DataTypeValidationSettings data_type_validation_settings;129};130
131
132FunctionOverloadResolverPtr createInternalCastOverloadResolver(CastType type, std::optional<CastDiagnostic> diagnostic)133{
134return CastOverloadResolverImpl::create(ContextPtr{}, type, true, diagnostic);135}
136
137REGISTER_FUNCTION(CastOverloadResolvers)138{
139factory.registerFunction("_CAST", [](ContextPtr context){ return CastOverloadResolverImpl::create(context, CastType::nonAccurate, true, {}); }, {}, FunctionFactory::CaseInsensitive);140/// Note: "internal" (not affected by null preserving setting) versions of accurate cast functions are unneeded.141
142factory.registerFunction("CAST", [](ContextPtr context){ return CastOverloadResolverImpl::create(context, CastType::nonAccurate, false, {}); }, {}, FunctionFactory::CaseInsensitive);143factory.registerFunction("accurateCast", [](ContextPtr context){ return CastOverloadResolverImpl::create(context, CastType::accurate, false, {}); }, {});144factory.registerFunction("accurateCastOrNull", [](ContextPtr context){ return CastOverloadResolverImpl::create(context, CastType::accurateOrNull, false, {}); }, {});145}
146
147}
148