ClickHouse

Форк
0
/
FunctionsStringDistance.cpp 
516 строк · 18.6 Кб
1
#include <Columns/ColumnString.h>
2
#include <Columns/ColumnsNumber.h>
3
#include <DataTypes/DataTypeString.h>
4
#include <DataTypes/DataTypesNumber.h>
5
#include <Functions/FunctionFactory.h>
6
#include <Functions/FunctionsStringSimilarity.h>
7
#include <Common/PODArray.h>
8
#include <Common/UTF8Helpers.h>
9
#include <Common/iota.h>
10

11
#include <numeric>
12

13
#ifdef __SSE4_2__
14
#    include <nmmintrin.h>
15
#endif
16

17
namespace DB
18
{
19
/// Error codes thrown by the functions in this file (defined elsewhere in the project).
namespace ErrorCodes
{
extern const int TOO_LARGE_STRING_SIZE;
extern const int BAD_ARGUMENTS;
}
24

25
template <typename Op>
26
struct FunctionStringDistanceImpl
27
{
28
    using ResultType = typename Op::ResultType;
29

30
    static void constantConstant(const String & haystack, const String & needle, ResultType & res)
31
    {
32
        res = Op::process(haystack.data(), haystack.size(), needle.data(), needle.size());
33
    }
34

35
    static void vectorVector(
36
        const ColumnString::Chars & haystack_data,
37
        const ColumnString::Offsets & haystack_offsets,
38
        const ColumnString::Chars & needle_data,
39
        const ColumnString::Offsets & needle_offsets,
40
        PaddedPODArray<ResultType> & res)
41
    {
42
        size_t size = res.size();
43
        const char * haystack = reinterpret_cast<const char *>(haystack_data.data());
44
        const char * needle = reinterpret_cast<const char *>(needle_data.data());
45
        for (size_t i = 0; i < size; ++i)
46
        {
47
            res[i] = Op::process(
48
                haystack + haystack_offsets[i - 1],
49
                haystack_offsets[i] - haystack_offsets[i - 1] - 1,
50
                needle + needle_offsets[i - 1],
51
                needle_offsets[i] - needle_offsets[i - 1] - 1);
52
        }
53
    }
54

55
    static void constantVector(
56
        const String & haystack,
57
        const ColumnString::Chars & needle_data,
58
        const ColumnString::Offsets & needle_offsets,
59
        PaddedPODArray<ResultType> & res)
60
    {
61
        const char * haystack_data = haystack.data();
62
        size_t haystack_size = haystack.size();
63
        const char * needle = reinterpret_cast<const char *>(needle_data.data());
64
        size_t size = res.size();
65
        for (size_t i = 0; i < size; ++i)
66
        {
67
            res[i] = Op::process(haystack_data, haystack_size,
68
                needle + needle_offsets[i - 1], needle_offsets[i] - needle_offsets[i - 1] - 1);
69
        }
70
    }
71

72
    static void vectorConstant(
73
        const ColumnString::Chars & data,
74
        const ColumnString::Offsets & offsets,
75
        const String & needle,
76
        PaddedPODArray<ResultType> & res)
77
    {
78
        constantVector(needle, data, offsets, res);
79
    }
80

81
};
82

83
struct ByteHammingDistanceImpl
84
{
85
    using ResultType = UInt64;
86
    static ResultType process(
87
        const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
88
    {
89
        UInt64 res = 0;
90
        const char * haystack_end = haystack + haystack_size;
91
        const char * needle_end = needle + needle_size;
92

93
#ifdef __SSE4_2__
94
        static constexpr auto mode = _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH | _SIDD_NEGATIVE_POLARITY;
95

96
        const char * haystack_end16 = haystack + haystack_size / 16 * 16;
97
        const char * needle_end16 = needle + needle_size / 16 * 16;
98

99
        for (; haystack < haystack_end16 && needle < needle_end16; haystack += 16, needle += 16)
100
        {
101
            __m128i s1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(haystack));
102
            __m128i s2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(needle));
103
            auto result_mask = _mm_cmpestrm(s1, 16, s2, 16, mode);
104
            const __m128i mask_hi = _mm_unpackhi_epi64(result_mask, result_mask);
105
            res += _mm_popcnt_u64(_mm_cvtsi128_si64(result_mask)) + _mm_popcnt_u64(_mm_cvtsi128_si64(mask_hi));
106
        }
107
#endif
108
        for (; haystack != haystack_end && needle != needle_end; ++haystack, ++needle)
109
            res += *haystack != *needle;
110

111
        res = res + (haystack_end - haystack) + (needle_end - needle);
112
        return res;
113
    }
114
};
115

116
template <bool is_utf8>
117
struct ByteJaccardIndexImpl
118
{
119
    using ResultType = Float64;
120
    static ResultType process(
121
        const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
122
    {
123
        if (haystack_size == 0 || needle_size == 0)
124
            return 0;
125

126
        const char * haystack_end = haystack + haystack_size;
127
        const char * needle_end = needle + needle_size;
128

129
        /// For byte strings use plain array as a set
130
        constexpr size_t max_size = std::numeric_limits<unsigned char>::max() + 1;
131
        std::array<UInt8, max_size> haystack_set;
132
        std::array<UInt8, max_size> needle_set;
133

134
        /// For UTF-8 strings we also use sets of code points greater than max_size
135
        std::set<UInt32> haystack_utf8_set;
136
        std::set<UInt32> needle_utf8_set;
137

138
        haystack_set.fill(0);
139
        needle_set.fill(0);
140

141
        while (haystack < haystack_end)
142
        {
143
            size_t len = 1;
144
            if constexpr (is_utf8)
145
                len = UTF8::seqLength(*haystack);
146

147
            if (len == 1)
148
            {
149
                haystack_set[static_cast<unsigned char>(*haystack)] = 1;
150
                ++haystack;
151
            }
152
            else
153
            {
154
                auto code_point = UTF8::convertUTF8ToCodePoint(haystack, haystack_end - haystack);
155
                if (code_point.has_value())
156
                {
157
                    haystack_utf8_set.insert(code_point.value());
158
                    haystack += len;
159
                }
160
                else
161
                {
162
                    throw Exception(ErrorCodes::BAD_ARGUMENTS, "Illegal UTF-8 sequence, while processing '{}'", StringRef(haystack, haystack_end - haystack));
163
                }
164
            }
165
        }
166

167
        while (needle < needle_end)
168
        {
169

170
            size_t len = 1;
171
            if constexpr (is_utf8)
172
                len = UTF8::seqLength(*needle);
173

174
            if (len == 1)
175
            {
176
                needle_set[static_cast<unsigned char>(*needle)] = 1;
177
                ++needle;
178
            }
179
            else
180
            {
181
                auto code_point = UTF8::convertUTF8ToCodePoint(needle, needle_end - needle);
182
                if (code_point.has_value())
183
                {
184
                    needle_utf8_set.insert(code_point.value());
185
                    needle += len;
186
                }
187
                else
188
                {
189
                    throw Exception(ErrorCodes::BAD_ARGUMENTS, "Illegal UTF-8 sequence, while processing '{}'", StringRef(needle, needle_end - needle));
190
                }
191
            }
192
        }
193

194
        UInt8 intersection = 0;
195
        UInt8 union_size = 0;
196

197
        if constexpr (is_utf8)
198
        {
199
            auto lit = haystack_utf8_set.begin();
200
            auto rit = needle_utf8_set.begin();
201
            while (lit != haystack_utf8_set.end() && rit != needle_utf8_set.end())
202
            {
203
                if (*lit == *rit)
204
                {
205
                    ++intersection;
206
                    ++lit;
207
                    ++rit;
208
                }
209
                else if (*lit < *rit)
210
                    ++lit;
211
                else
212
                    ++rit;
213
            }
214
            union_size = haystack_utf8_set.size() + needle_utf8_set.size() - intersection;
215
        }
216

217
        for (size_t i = 0; i < max_size; ++i)
218
        {
219
            intersection += haystack_set[i] & needle_set[i];
220
            union_size += haystack_set[i] | needle_set[i];
221
        }
222

223
        return static_cast<ResultType>(intersection) / static_cast<ResultType>(union_size);
224
    }
225
};
226

227
/// Maximum size in bytes (2^16) of each argument accepted by editDistance,
/// damerauLevenshteinDistance, jaroSimilarity and jaroWinklerSimilarity; larger inputs throw.
static constexpr size_t max_string_size = 65536;
228

229
struct ByteEditDistanceImpl
230
{
231
    using ResultType = UInt64;
232

233
    static ResultType process(
234
        const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
235
    {
236
        if (haystack_size == 0 || needle_size == 0)
237
            return haystack_size + needle_size;
238

239
        /// Safety threshold against DoS, since we use two arrays to calculate the distance.
240
        if (haystack_size > max_string_size || needle_size > max_string_size)
241
            throw Exception(
242
                ErrorCodes::TOO_LARGE_STRING_SIZE,
243
                "The string size is too big for function editDistance, should be at most {}", max_string_size);
244

245
        PaddedPODArray<ResultType> distances0(haystack_size + 1, 0);
246
        PaddedPODArray<ResultType> distances1(haystack_size + 1, 0);
247

248
        ResultType substitution = 0;
249
        ResultType insertion = 0;
250
        ResultType deletion = 0;
251

252
        iota(distances0.data(), haystack_size + 1, ResultType(0));
253

254
        for (size_t pos_needle = 0; pos_needle < needle_size; ++pos_needle)
255
        {
256
            distances1[0] = pos_needle + 1;
257

258
            for (size_t pos_haystack = 0; pos_haystack < haystack_size; pos_haystack++)
259
            {
260
                deletion = distances0[pos_haystack + 1] + 1;
261
                insertion = distances1[pos_haystack] + 1;
262
                substitution = distances0[pos_haystack];
263

264
                if (*(needle + pos_needle) != *(haystack + pos_haystack))
265
                    substitution += 1;
266

267
                distances1[pos_haystack + 1] = std::min(deletion, std::min(substitution, insertion));
268
            }
269
            distances0.swap(distances1);
270
        }
271

272
        return distances0[haystack_size];
273
    }
274
};
275

276
struct ByteDamerauLevenshteinDistanceImpl
277
{
278
    using ResultType = UInt64;
279

280
    static ResultType process(
281
        const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
282
    {
283
        /// Safety threshold against DoS
284
        if (haystack_size > max_string_size || needle_size > max_string_size)
285
            throw Exception(
286
                ErrorCodes::TOO_LARGE_STRING_SIZE,
287
                "The string size is too big for function damerauLevenshteinDistance, should be at most {}", max_string_size);
288

289
        /// Shortcuts:
290

291
        if (haystack_size == 0)
292
            return needle_size;
293

294
        if (needle_size == 0)
295
            return haystack_size;
296

297
        if (haystack_size == needle_size && memcmp(haystack, needle, haystack_size) == 0)
298
            return 0;
299

300
        /// Implements the algorithm for optimal string alignment distance from
301
        /// https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance#Optimal_string_alignment_distance
302

303
        /// Dynamically allocate memory for the 2D array
304
        /// Allocating a 2D array, for convenience starts is an array of pointers to the start of the rows.
305
        std::vector<int> d((needle_size + 1) * (haystack_size + 1));
306
        std::vector<int *> starts(haystack_size + 1);
307

308
        /// Setting the pointers in starts to the beginning of (needle_size + 1)-long intervals.
309
        /// Also initialize the row values based on the mentioned algorithm.
310
        for (size_t i = 0; i <= haystack_size; ++i)
311
        {
312
            starts[i] = d.data() + (needle_size + 1) * i;
313
            starts[i][0] = static_cast<int>(i);
314
        }
315

316
        for (size_t j = 0; j <= needle_size; ++j)
317
        {
318
            starts[0][j] = static_cast<int>(j);
319
        }
320

321
        for (size_t i = 1; i <= haystack_size; ++i)
322
        {
323
            for (size_t j = 1; j <= needle_size; ++j)
324
            {
325
                int cost = (haystack[i - 1] == needle[j - 1]) ? 0 : 1;
326
                starts[i][j] = std::min(starts[i - 1][j] + 1,                  /// deletion
327
                                        std::min(starts[i][j - 1] + 1,         /// insertion
328
                                                 starts[i - 1][j - 1] + cost)  /// substitution
329
                               );
330
                if (i > 1 && j > 1 && haystack[i - 1] == needle[j - 2] && haystack[i - 2] == needle[j - 1])
331
                    starts[i][j] = std::min(starts[i][j], starts[i - 2][j - 2] + 1); /// transposition
332
            }
333
        }
334

335
        return starts[haystack_size][needle_size];
336
    }
337
};
338

339
struct ByteJaroSimilarityImpl
340
{
341
    using ResultType = Float64;
342

343
    static ResultType process(
344
        const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
345
    {
346
        /// Safety threshold against DoS
347
        if (haystack_size > max_string_size || needle_size > max_string_size)
348
            throw Exception(
349
                ErrorCodes::TOO_LARGE_STRING_SIZE,
350
                "The string size is too big for function jaroSimilarity, should be at most {}", max_string_size);
351

352
        /// Shortcuts:
353

354
        if (haystack_size == 0)
355
            return needle_size;
356

357
        if (needle_size == 0)
358
            return haystack_size;
359

360
        if (haystack_size == needle_size && memcmp(haystack, needle, haystack_size) == 0)
361
            return 1.0;
362

363
        const int s1len = static_cast<int>(haystack_size);
364
        const int s2len = static_cast<int>(needle_size);
365

366
        /// Window size to search for matches in the other string
367
        const int max_range = std::max(0, std::max(s1len, s2len) / 2 - 1);
368
        std::vector<int> s1_matching(s1len, -1);
369
        std::vector<int> s2_matching(s2len, -1);
370

371
        /// Calculate matching characters
372
        size_t matching_characters = 0;
373
        for (int i = 0; i < s1len; i++)
374
        {
375
            /// Matching window
376
            const int min_index = std::max(i - max_range, 0);
377
            const int max_index = std::min(i + max_range + 1, s2len);
378
            for (int j = min_index; j < max_index; j++)
379
            {
380
                if (s2_matching[j] == -1 && haystack[i] == needle[j])
381
                {
382
                    s1_matching[i] = i;
383
                    s2_matching[j] = j;
384
                    matching_characters++;
385
                    break;
386
                }
387
            }
388
        }
389

390
        if (matching_characters == 0)
391
            return 0.0;
392

393
        /// Transpositions (one-way only)
394
        double transpositions = 0.0;
395
        for (size_t i = 0, s1i = 0, s2i = 0; i < matching_characters; i++)
396
        {
397
            while (s1_matching[s1i] == -1)
398
                s1i++;
399
            while (s2_matching[s2i] == -1)
400
                s2i++;
401
            if (haystack[s1i] != needle[s2i])
402
                transpositions += 0.5;
403
            s1i++;
404
            s2i++;
405
        }
406

407
        double m = static_cast<double>(matching_characters);
408
        double jaro_similarity = 1.0 / 3.0  * (m / static_cast<double>(s1len)
409
                                            + m / static_cast<double>(s2len)
410
                                            + (m - transpositions) / m);
411
        return jaro_similarity;
412
    }
413
};
414

415
struct ByteJaroWinklerSimilarityImpl
416
{
417
    using ResultType = Float64;
418

419
    static ResultType process(
420
        const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
421
    {
422
        static constexpr int max_prefix_length = 4;
423
        static constexpr double scaling_factor =  0.1;
424
        static constexpr double boost_threshold = 0.7;
425

426
        /// Safety threshold against DoS
427
        if (haystack_size > max_string_size || needle_size > max_string_size)
428
            throw Exception(
429
                ErrorCodes::TOO_LARGE_STRING_SIZE,
430
                "The string size is too big for function jaroWinklerSimilarity, should be at most {}", max_string_size);
431

432
        const int s1len = static_cast<int>(haystack_size);
433
        const int s2len = static_cast<int>(needle_size);
434

435
        ResultType jaro_winkler_similarity = ByteJaroSimilarityImpl::process(haystack, haystack_size, needle, needle_size);
436

437
        if (jaro_winkler_similarity > boost_threshold)
438
        {
439
            const int common_length = std::min(max_prefix_length, std::min(s1len, s2len));
440
            int common_prefix = 0;
441
            while (common_prefix < common_length && haystack[common_prefix] == needle[common_prefix])
442
                common_prefix++;
443

444
            jaro_winkler_similarity += common_prefix * scaling_factor * (1.0 - jaro_winkler_similarity);
445
        }
446
        return jaro_winkler_similarity;
447
    }
448
};
449

450
struct NameByteHammingDistance
451
{
452
    static constexpr auto name = "byteHammingDistance";
453
};
454
using FunctionByteHammingDistance = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteHammingDistanceImpl>, NameByteHammingDistance>;
455

456
struct NameEditDistance
457
{
458
    static constexpr auto name = "editDistance";
459
};
460
using FunctionEditDistance = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteEditDistanceImpl>, NameEditDistance>;
461

462
struct NameDamerauLevenshteinDistance
463
{
464
    static constexpr auto name = "damerauLevenshteinDistance";
465
};
466
using FunctionDamerauLevenshteinDistance = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteDamerauLevenshteinDistanceImpl>, NameDamerauLevenshteinDistance>;
467

468
struct NameJaccardIndex
469
{
470
    static constexpr auto name = "stringJaccardIndex";
471
};
472
using FunctionStringJaccardIndex = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaccardIndexImpl<false>>, NameJaccardIndex>;
473

474
struct NameJaccardIndexUTF8
475
{
476
    static constexpr auto name = "stringJaccardIndexUTF8";
477
};
478
using FunctionStringJaccardIndexUTF8 = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaccardIndexImpl<true>>, NameJaccardIndexUTF8>;
479

480
struct NameJaroSimilarity
481
{
482
    static constexpr auto name = "jaroSimilarity";
483
};
484
using FunctionJaroSimilarity = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaroSimilarityImpl>, NameJaroSimilarity>;
485

486
struct NameJaroWinklerSimilarity
487
{
488
    static constexpr auto name = "jaroWinklerSimilarity";
489
};
490
using FunctionJaroWinklerSimilarity = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaroWinklerSimilarityImpl>, NameJaroWinklerSimilarity>;
491

492
/// Registers the string distance/similarity functions of this file and their aliases.
REGISTER_FUNCTION(StringDistance)
{
    factory.registerFunction<FunctionByteHammingDistance>(
        FunctionDocumentation{.description = R"(Calculates Hamming distance between two byte-strings.)"});
    factory.registerAlias("mismatches", NameByteHammingDistance::name);

    factory.registerFunction<FunctionEditDistance>(
        FunctionDocumentation{.description = R"(Calculates the edit distance between two byte-strings.)"});
    factory.registerAlias("levenshteinDistance", NameEditDistance::name);

    /// The implementation computes the optimal string alignment variant of this distance.
    factory.registerFunction<FunctionDamerauLevenshteinDistance>(
        FunctionDocumentation{.description = R"(Calculates the Damerau-Levenshtein distance between two byte-strings.)"});

    factory.registerFunction<FunctionStringJaccardIndex>(
        FunctionDocumentation{.description = R"(Calculates the Jaccard similarity index between two byte strings.)"});
    factory.registerFunction<FunctionStringJaccardIndexUTF8>(
        FunctionDocumentation{.description = R"(Calculates the Jaccard similarity index between two UTF8 strings.)"});

    factory.registerFunction<FunctionJaroSimilarity>(
        FunctionDocumentation{.description = R"(Calculates the Jaro similarity between two byte-strings.)"});

    factory.registerFunction<FunctionJaroWinklerSimilarity>(
        FunctionDocumentation{.description = R"(Calculates the Jaro-Winkler similarity between two byte-strings.)"});
}
516
}
517

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.