// ClickHouse
// 516 lines · 18.6 KB
1#include <Columns/ColumnString.h>
2#include <Columns/ColumnsNumber.h>
3#include <DataTypes/DataTypeString.h>
4#include <DataTypes/DataTypesNumber.h>
5#include <Functions/FunctionFactory.h>
6#include <Functions/FunctionsStringSimilarity.h>
7#include <Common/PODArray.h>
8#include <Common/UTF8Helpers.h>
9#include <Common/iota.h>
10
11#include <numeric>
12
13#ifdef __SSE4_2__
14# include <nmmintrin.h>
15#endif
16
17namespace DB
18{
/// Error codes thrown by the functions in this file; only declared here.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int TOO_LARGE_STRING_SIZE;
}
24
25template <typename Op>
26struct FunctionStringDistanceImpl
27{
28using ResultType = typename Op::ResultType;
29
30static void constantConstant(const String & haystack, const String & needle, ResultType & res)
31{
32res = Op::process(haystack.data(), haystack.size(), needle.data(), needle.size());
33}
34
35static void vectorVector(
36const ColumnString::Chars & haystack_data,
37const ColumnString::Offsets & haystack_offsets,
38const ColumnString::Chars & needle_data,
39const ColumnString::Offsets & needle_offsets,
40PaddedPODArray<ResultType> & res)
41{
42size_t size = res.size();
43const char * haystack = reinterpret_cast<const char *>(haystack_data.data());
44const char * needle = reinterpret_cast<const char *>(needle_data.data());
45for (size_t i = 0; i < size; ++i)
46{
47res[i] = Op::process(
48haystack + haystack_offsets[i - 1],
49haystack_offsets[i] - haystack_offsets[i - 1] - 1,
50needle + needle_offsets[i - 1],
51needle_offsets[i] - needle_offsets[i - 1] - 1);
52}
53}
54
55static void constantVector(
56const String & haystack,
57const ColumnString::Chars & needle_data,
58const ColumnString::Offsets & needle_offsets,
59PaddedPODArray<ResultType> & res)
60{
61const char * haystack_data = haystack.data();
62size_t haystack_size = haystack.size();
63const char * needle = reinterpret_cast<const char *>(needle_data.data());
64size_t size = res.size();
65for (size_t i = 0; i < size; ++i)
66{
67res[i] = Op::process(haystack_data, haystack_size,
68needle + needle_offsets[i - 1], needle_offsets[i] - needle_offsets[i - 1] - 1);
69}
70}
71
72static void vectorConstant(
73const ColumnString::Chars & data,
74const ColumnString::Offsets & offsets,
75const String & needle,
76PaddedPODArray<ResultType> & res)
77{
78constantVector(needle, data, offsets, res);
79}
80
81};
82
83struct ByteHammingDistanceImpl
84{
85using ResultType = UInt64;
86static ResultType process(
87const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
88{
89UInt64 res = 0;
90const char * haystack_end = haystack + haystack_size;
91const char * needle_end = needle + needle_size;
92
93#ifdef __SSE4_2__
94static constexpr auto mode = _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_EACH | _SIDD_NEGATIVE_POLARITY;
95
96const char * haystack_end16 = haystack + haystack_size / 16 * 16;
97const char * needle_end16 = needle + needle_size / 16 * 16;
98
99for (; haystack < haystack_end16 && needle < needle_end16; haystack += 16, needle += 16)
100{
101__m128i s1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(haystack));
102__m128i s2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(needle));
103auto result_mask = _mm_cmpestrm(s1, 16, s2, 16, mode);
104const __m128i mask_hi = _mm_unpackhi_epi64(result_mask, result_mask);
105res += _mm_popcnt_u64(_mm_cvtsi128_si64(result_mask)) + _mm_popcnt_u64(_mm_cvtsi128_si64(mask_hi));
106}
107#endif
108for (; haystack != haystack_end && needle != needle_end; ++haystack, ++needle)
109res += *haystack != *needle;
110
111res = res + (haystack_end - haystack) + (needle_end - needle);
112return res;
113}
114};
115
116template <bool is_utf8>
117struct ByteJaccardIndexImpl
118{
119using ResultType = Float64;
120static ResultType process(
121const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
122{
123if (haystack_size == 0 || needle_size == 0)
124return 0;
125
126const char * haystack_end = haystack + haystack_size;
127const char * needle_end = needle + needle_size;
128
129/// For byte strings use plain array as a set
130constexpr size_t max_size = std::numeric_limits<unsigned char>::max() + 1;
131std::array<UInt8, max_size> haystack_set;
132std::array<UInt8, max_size> needle_set;
133
134/// For UTF-8 strings we also use sets of code points greater than max_size
135std::set<UInt32> haystack_utf8_set;
136std::set<UInt32> needle_utf8_set;
137
138haystack_set.fill(0);
139needle_set.fill(0);
140
141while (haystack < haystack_end)
142{
143size_t len = 1;
144if constexpr (is_utf8)
145len = UTF8::seqLength(*haystack);
146
147if (len == 1)
148{
149haystack_set[static_cast<unsigned char>(*haystack)] = 1;
150++haystack;
151}
152else
153{
154auto code_point = UTF8::convertUTF8ToCodePoint(haystack, haystack_end - haystack);
155if (code_point.has_value())
156{
157haystack_utf8_set.insert(code_point.value());
158haystack += len;
159}
160else
161{
162throw Exception(ErrorCodes::BAD_ARGUMENTS, "Illegal UTF-8 sequence, while processing '{}'", StringRef(haystack, haystack_end - haystack));
163}
164}
165}
166
167while (needle < needle_end)
168{
169
170size_t len = 1;
171if constexpr (is_utf8)
172len = UTF8::seqLength(*needle);
173
174if (len == 1)
175{
176needle_set[static_cast<unsigned char>(*needle)] = 1;
177++needle;
178}
179else
180{
181auto code_point = UTF8::convertUTF8ToCodePoint(needle, needle_end - needle);
182if (code_point.has_value())
183{
184needle_utf8_set.insert(code_point.value());
185needle += len;
186}
187else
188{
189throw Exception(ErrorCodes::BAD_ARGUMENTS, "Illegal UTF-8 sequence, while processing '{}'", StringRef(needle, needle_end - needle));
190}
191}
192}
193
194UInt8 intersection = 0;
195UInt8 union_size = 0;
196
197if constexpr (is_utf8)
198{
199auto lit = haystack_utf8_set.begin();
200auto rit = needle_utf8_set.begin();
201while (lit != haystack_utf8_set.end() && rit != needle_utf8_set.end())
202{
203if (*lit == *rit)
204{
205++intersection;
206++lit;
207++rit;
208}
209else if (*lit < *rit)
210++lit;
211else
212++rit;
213}
214union_size = haystack_utf8_set.size() + needle_utf8_set.size() - intersection;
215}
216
217for (size_t i = 0; i < max_size; ++i)
218{
219intersection += haystack_set[i] & needle_set[i];
220union_size += haystack_set[i] | needle_set[i];
221}
222
223return static_cast<ResultType>(intersection) / static_cast<ResultType>(union_size);
224}
225};
226
/// Upper bound (in bytes, 2^16) on the length of each argument accepted by the
/// distance/similarity functions below that allocate memory proportional to the input.
static constexpr size_t max_string_size = size_t{1} << 16;
228
229struct ByteEditDistanceImpl
230{
231using ResultType = UInt64;
232
233static ResultType process(
234const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
235{
236if (haystack_size == 0 || needle_size == 0)
237return haystack_size + needle_size;
238
239/// Safety threshold against DoS, since we use two arrays to calculate the distance.
240if (haystack_size > max_string_size || needle_size > max_string_size)
241throw Exception(
242ErrorCodes::TOO_LARGE_STRING_SIZE,
243"The string size is too big for function editDistance, should be at most {}", max_string_size);
244
245PaddedPODArray<ResultType> distances0(haystack_size + 1, 0);
246PaddedPODArray<ResultType> distances1(haystack_size + 1, 0);
247
248ResultType substitution = 0;
249ResultType insertion = 0;
250ResultType deletion = 0;
251
252iota(distances0.data(), haystack_size + 1, ResultType(0));
253
254for (size_t pos_needle = 0; pos_needle < needle_size; ++pos_needle)
255{
256distances1[0] = pos_needle + 1;
257
258for (size_t pos_haystack = 0; pos_haystack < haystack_size; pos_haystack++)
259{
260deletion = distances0[pos_haystack + 1] + 1;
261insertion = distances1[pos_haystack] + 1;
262substitution = distances0[pos_haystack];
263
264if (*(needle + pos_needle) != *(haystack + pos_haystack))
265substitution += 1;
266
267distances1[pos_haystack + 1] = std::min(deletion, std::min(substitution, insertion));
268}
269distances0.swap(distances1);
270}
271
272return distances0[haystack_size];
273}
274};
275
276struct ByteDamerauLevenshteinDistanceImpl
277{
278using ResultType = UInt64;
279
280static ResultType process(
281const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
282{
283/// Safety threshold against DoS
284if (haystack_size > max_string_size || needle_size > max_string_size)
285throw Exception(
286ErrorCodes::TOO_LARGE_STRING_SIZE,
287"The string size is too big for function damerauLevenshteinDistance, should be at most {}", max_string_size);
288
289/// Shortcuts:
290
291if (haystack_size == 0)
292return needle_size;
293
294if (needle_size == 0)
295return haystack_size;
296
297if (haystack_size == needle_size && memcmp(haystack, needle, haystack_size) == 0)
298return 0;
299
300/// Implements the algorithm for optimal string alignment distance from
301/// https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance#Optimal_string_alignment_distance
302
303/// Dynamically allocate memory for the 2D array
304/// Allocating a 2D array, for convenience starts is an array of pointers to the start of the rows.
305std::vector<int> d((needle_size + 1) * (haystack_size + 1));
306std::vector<int *> starts(haystack_size + 1);
307
308/// Setting the pointers in starts to the beginning of (needle_size + 1)-long intervals.
309/// Also initialize the row values based on the mentioned algorithm.
310for (size_t i = 0; i <= haystack_size; ++i)
311{
312starts[i] = d.data() + (needle_size + 1) * i;
313starts[i][0] = static_cast<int>(i);
314}
315
316for (size_t j = 0; j <= needle_size; ++j)
317{
318starts[0][j] = static_cast<int>(j);
319}
320
321for (size_t i = 1; i <= haystack_size; ++i)
322{
323for (size_t j = 1; j <= needle_size; ++j)
324{
325int cost = (haystack[i - 1] == needle[j - 1]) ? 0 : 1;
326starts[i][j] = std::min(starts[i - 1][j] + 1, /// deletion
327std::min(starts[i][j - 1] + 1, /// insertion
328starts[i - 1][j - 1] + cost) /// substitution
329);
330if (i > 1 && j > 1 && haystack[i - 1] == needle[j - 2] && haystack[i - 2] == needle[j - 1])
331starts[i][j] = std::min(starts[i][j], starts[i - 2][j - 2] + 1); /// transposition
332}
333}
334
335return starts[haystack_size][needle_size];
336}
337};
338
339struct ByteJaroSimilarityImpl
340{
341using ResultType = Float64;
342
343static ResultType process(
344const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
345{
346/// Safety threshold against DoS
347if (haystack_size > max_string_size || needle_size > max_string_size)
348throw Exception(
349ErrorCodes::TOO_LARGE_STRING_SIZE,
350"The string size is too big for function jaroSimilarity, should be at most {}", max_string_size);
351
352/// Shortcuts:
353
354if (haystack_size == 0)
355return needle_size;
356
357if (needle_size == 0)
358return haystack_size;
359
360if (haystack_size == needle_size && memcmp(haystack, needle, haystack_size) == 0)
361return 1.0;
362
363const int s1len = static_cast<int>(haystack_size);
364const int s2len = static_cast<int>(needle_size);
365
366/// Window size to search for matches in the other string
367const int max_range = std::max(0, std::max(s1len, s2len) / 2 - 1);
368std::vector<int> s1_matching(s1len, -1);
369std::vector<int> s2_matching(s2len, -1);
370
371/// Calculate matching characters
372size_t matching_characters = 0;
373for (int i = 0; i < s1len; i++)
374{
375/// Matching window
376const int min_index = std::max(i - max_range, 0);
377const int max_index = std::min(i + max_range + 1, s2len);
378for (int j = min_index; j < max_index; j++)
379{
380if (s2_matching[j] == -1 && haystack[i] == needle[j])
381{
382s1_matching[i] = i;
383s2_matching[j] = j;
384matching_characters++;
385break;
386}
387}
388}
389
390if (matching_characters == 0)
391return 0.0;
392
393/// Transpositions (one-way only)
394double transpositions = 0.0;
395for (size_t i = 0, s1i = 0, s2i = 0; i < matching_characters; i++)
396{
397while (s1_matching[s1i] == -1)
398s1i++;
399while (s2_matching[s2i] == -1)
400s2i++;
401if (haystack[s1i] != needle[s2i])
402transpositions += 0.5;
403s1i++;
404s2i++;
405}
406
407double m = static_cast<double>(matching_characters);
408double jaro_similarity = 1.0 / 3.0 * (m / static_cast<double>(s1len)
409+ m / static_cast<double>(s2len)
410+ (m - transpositions) / m);
411return jaro_similarity;
412}
413};
414
415struct ByteJaroWinklerSimilarityImpl
416{
417using ResultType = Float64;
418
419static ResultType process(
420const char * __restrict haystack, size_t haystack_size, const char * __restrict needle, size_t needle_size)
421{
422static constexpr int max_prefix_length = 4;
423static constexpr double scaling_factor = 0.1;
424static constexpr double boost_threshold = 0.7;
425
426/// Safety threshold against DoS
427if (haystack_size > max_string_size || needle_size > max_string_size)
428throw Exception(
429ErrorCodes::TOO_LARGE_STRING_SIZE,
430"The string size is too big for function jaroWinklerSimilarity, should be at most {}", max_string_size);
431
432const int s1len = static_cast<int>(haystack_size);
433const int s2len = static_cast<int>(needle_size);
434
435ResultType jaro_winkler_similarity = ByteJaroSimilarityImpl::process(haystack, haystack_size, needle, needle_size);
436
437if (jaro_winkler_similarity > boost_threshold)
438{
439const int common_length = std::min(max_prefix_length, std::min(s1len, s2len));
440int common_prefix = 0;
441while (common_prefix < common_length && haystack[common_prefix] == needle[common_prefix])
442common_prefix++;
443
444jaro_winkler_similarity += common_prefix * scaling_factor * (1.0 - jaro_winkler_similarity);
445}
446return jaro_winkler_similarity;
447}
448};
449
450struct NameByteHammingDistance
451{
452static constexpr auto name = "byteHammingDistance";
453};
454using FunctionByteHammingDistance = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteHammingDistanceImpl>, NameByteHammingDistance>;
455
456struct NameEditDistance
457{
458static constexpr auto name = "editDistance";
459};
460using FunctionEditDistance = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteEditDistanceImpl>, NameEditDistance>;
461
462struct NameDamerauLevenshteinDistance
463{
464static constexpr auto name = "damerauLevenshteinDistance";
465};
466using FunctionDamerauLevenshteinDistance = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteDamerauLevenshteinDistanceImpl>, NameDamerauLevenshteinDistance>;
467
468struct NameJaccardIndex
469{
470static constexpr auto name = "stringJaccardIndex";
471};
472using FunctionStringJaccardIndex = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaccardIndexImpl<false>>, NameJaccardIndex>;
473
474struct NameJaccardIndexUTF8
475{
476static constexpr auto name = "stringJaccardIndexUTF8";
477};
478using FunctionStringJaccardIndexUTF8 = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaccardIndexImpl<true>>, NameJaccardIndexUTF8>;
479
480struct NameJaroSimilarity
481{
482static constexpr auto name = "jaroSimilarity";
483};
484using FunctionJaroSimilarity = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaroSimilarityImpl>, NameJaroSimilarity>;
485
486struct NameJaroWinklerSimilarity
487{
488static constexpr auto name = "jaroWinklerSimilarity";
489};
490using FunctionJaroWinklerSimilarity = FunctionsStringSimilarity<FunctionStringDistanceImpl<ByteJaroWinklerSimilarityImpl>, NameJaroWinklerSimilarity>;
491
492REGISTER_FUNCTION(StringDistance)
493{
494factory.registerFunction<FunctionByteHammingDistance>(
495FunctionDocumentation{.description = R"(Calculates Hamming distance between two byte-strings.)"});
496factory.registerAlias("mismatches", NameByteHammingDistance::name);
497
498factory.registerFunction<FunctionEditDistance>(
499FunctionDocumentation{.description = R"(Calculates the edit distance between two byte-strings.)"});
500factory.registerAlias("levenshteinDistance", NameEditDistance::name);
501
502factory.registerFunction<FunctionDamerauLevenshteinDistance>(
503FunctionDocumentation{.description = R"(Calculates the Damerau-Levenshtein distance two between two byte-string.)"});
504
505factory.registerFunction<FunctionStringJaccardIndex>(
506FunctionDocumentation{.description = R"(Calculates the Jaccard similarity index between two byte strings.)"});
507factory.registerFunction<FunctionStringJaccardIndexUTF8>(
508FunctionDocumentation{.description = R"(Calculates the Jaccard similarity index between two UTF8 strings.)"});
509
510factory.registerFunction<FunctionJaroSimilarity>(
511FunctionDocumentation{.description = R"(Calculates the Jaro similarity between two byte-string.)"});
512
513factory.registerFunction<FunctionJaroWinklerSimilarity>(
514FunctionDocumentation{.description = R"(Calculates the Jaro-Winkler similarity between two byte-string.)"});
515}
516}
517