// Source: ClickHouse repository (fork), file catboostEvaluate.cpp — 182 lines, 6.5 KB.
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>

#include <BridgeHelper/CatBoostLibraryBridgeHelper.h>
#include <BridgeHelper/IBridgeHelper.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnsNumber.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunction.h>
#include <Interpreters/Context.h>
#include <Interpreters/Context_fwd.h>

#include <filesystem>
#include <memory>

namespace DB
20
{
21

22
namespace ErrorCodes
23
{
24
    extern const int FILE_DOESNT_EXIST;
25
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
26
    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
27
    extern const int ILLEGAL_COLUMN;
28
}
29

30
/// Evaluate CatBoost model.
31
/// - Arguments: float features first, then categorical features.
32
/// - Result: Float64.
33
class FunctionCatBoostEvaluate final : public IFunction, WithContext
34
{
35
private:
36
    mutable std::unique_ptr<CatBoostLibraryBridgeHelper> bridge_helper;
37

38
public:
39
    static constexpr auto name = "catboostEvaluate";
40

41
    static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionCatBoostEvaluate>(context_); }
42

43
    explicit FunctionCatBoostEvaluate(ContextPtr context_) : WithContext(context_) {}
44
    String getName() const override { return name; }
45
    bool isVariadic() const override { return true; }
46
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
47
    bool isDeterministic() const override { return false; }
48
    bool useDefaultImplementationForNulls() const override { return false; }
49
    size_t getNumberOfArguments() const override { return 0; }
50

51
    void initBridge(const ColumnConst * name_col) const
52
    {
53
        String library_path = getContext()->getConfigRef().getString("catboost_lib_path");
54
        if (!std::filesystem::exists(library_path))
55
            throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load library {}: file doesn't exist", library_path);
56

57
        String model_path = name_col->getValue<String>();
58
        if (!std::filesystem::exists(model_path))
59
            throw Exception(ErrorCodes::FILE_DOESNT_EXIST, "Can't load model {}: file doesn't exist", model_path);
60

61
        bridge_helper = std::make_unique<CatBoostLibraryBridgeHelper>(getContext(), model_path, library_path);
62
    }
63

64
    DataTypePtr getReturnTypeFromLibraryBridge() const
65
    {
66
        size_t tree_count = bridge_helper->getTreeCount();
67
        auto type = std::make_shared<DataTypeFloat64>();
68
        if (tree_count == 1)
69
            return type;
70

71
        DataTypes types(tree_count, type);
72

73
        return std::make_shared<DataTypeTuple>(types);
74
    }
75

76
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
77
    {
78
        if (arguments.size() < 2)
79
            throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} expects at least 2 arguments", getName());
80

81
        if (!isString(arguments[0].type))
82
            throw Exception(
83
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
84
                "Illegal type {} of first argument of function {}, expected a string.", arguments[0].type->getName(), getName());
85

86
        const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
87
        if (!name_col)
88
            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
89

90
        initBridge(name_col);
91

92
        auto type = getReturnTypeFromLibraryBridge();
93

94
        bool has_nullable = false;
95
        for (size_t i = 1; i < arguments.size(); ++i)
96
            has_nullable = has_nullable || arguments[i].type->isNullable();
97

98
        if (has_nullable)
99
        {
100
            if (const auto * tuple = typeid_cast<const DataTypeTuple *>(type.get()))
101
            {
102
                auto elements = tuple->getElements();
103
                for (auto & element : elements)
104
                    element = makeNullable(element);
105

106
                type = std::make_shared<DataTypeTuple>(elements);
107
            }
108
            else
109
                type = makeNullable(type);
110
        }
111

112
        return type;
113
    }
114

115
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
116
    {
117
        const auto * name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
118
        if (!name_col)
119
            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of function {} must be a constant string", getName());
120

121
        ColumnRawPtrs column_ptrs;
122
        Columns materialized_columns;
123
        ColumnPtr null_map;
124

125
        ColumnsWithTypeAndName feature_arguments(arguments.begin() + 1, arguments.end());
126
        for (auto & arg : feature_arguments)
127
        {
128
            if (auto full_column = arg.column->convertToFullColumnIfConst())
129
            {
130
                materialized_columns.push_back(full_column);
131
                arg.column = full_column;
132
            }
133
            if (const auto * col_nullable = checkAndGetColumn<ColumnNullable>(&*arg.column))
134
            {
135
                if (!null_map)
136
                    null_map = col_nullable->getNullMapColumnPtr();
137
                else
138
                {
139
                    auto mut_null_map = IColumn::mutate(std::move(null_map));
140

141
                    NullMap & result_null_map = assert_cast<ColumnUInt8 &>(*mut_null_map).getData();
142
                    const NullMap & src_null_map = col_nullable->getNullMapColumn().getData();
143

144
                    for (size_t i = 0, size = result_null_map.size(); i < size; ++i)
145
                        if (src_null_map[i])
146
                            result_null_map[i] = 1;
147

148
                    null_map = std::move(mut_null_map);
149
                }
150

151
                arg.column = col_nullable->getNestedColumn().getPtr();
152
                arg.type = static_cast<const DataTypeNullable &>(*arg.type).getNestedType();
153
            }
154
        }
155

156
        auto res = bridge_helper->evaluate(feature_arguments);
157

158
        if (null_map)
159
        {
160
            if (const auto * tuple = typeid_cast<const ColumnTuple *>(res.get()))
161
            {
162
                auto nested = tuple->getColumns();
163
                for (auto & col : nested)
164
                    col = ColumnNullable::create(col, null_map);
165

166
                res = ColumnTuple::create(nested);
167
            }
168
            else
169
                res = ColumnNullable::create(res, null_map);
170
        }
171

172
        return res;
173
    }
174
};
175

176

177
REGISTER_FUNCTION(CatBoostEvaluate)
178
{
179
    factory.registerFunction<FunctionCatBoostEvaluate>();
180
}
181

182
}

/*
 * Page footer captured together with the source (cookie notice of the hosting site):
 *
 * Использование cookies
 *
 * Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
 *
 * Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
 *
 * Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.
 */