ClickHouse
766 строк · 27.0 Кб
1#include <AggregateFunctions/Helpers.h>
2#include <AggregateFunctions/AggregateFunctionFactory.h>
3
4#include <DataTypes/DataTypeDate.h>
5#include <DataTypes/DataTypeDateTime.h>
6
7#include <AggregateFunctions/IAggregateFunction.h>
8#include <DataTypes/DataTypesNumber.h>
9#include <Columns/ColumnsNumber.h>
10#include <Common/assert_cast.h>
11#include <IO/ReadHelpers.h>
12#include <IO/WriteHelpers.h>
13#include <base/range.h>
14
15#include <bitset>
16#include <stack>
17
18
19namespace DB
20{
21
/// Forward declaration only: the factory in this file receives a `const Settings *` it never uses.
struct Settings;

/// Error codes thrown from this file. They are only declared here (`extern`); the definitions live elsewhere.
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int LOGICAL_ERROR;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int SYNTAX_ERROR;
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
extern const int TOO_SLOW;
}
35
36namespace
37{
38
/// Function object ordering `std::pair`s by their `.first` member only, using `Comparator<First>`.
/// The `.second` members never take part in the comparison.
template <template <typename> class Comparator>
struct ComparePairFirst final
{
    template <typename First, typename Second>
    bool operator()(const std::pair<First, Second> & left, const std::pair<First, Second> & right) const
    {
        const Comparator<First> compare_keys{};
        return compare_keys(left.first, right.first);
    }
};
49
/// Maximum number of condition arguments: each row's conditions are packed into a std::bitset of this size.
constexpr size_t max_events{32};
51
52template <typename T>
53struct AggregateFunctionSequenceMatchData final
54{
55using Timestamp = T;
56using Events = std::bitset<max_events>;
57using TimestampEvents = std::pair<Timestamp, Events>;
58using Comparator = ComparePairFirst<std::less>;
59
60bool sorted = true;
61PODArrayWithStackMemory<TimestampEvents, 64> events_list;
62/// sequenceMatch conditions met at least once in events_list
63Events conditions_met;
64
65void add(const Timestamp timestamp, const Events & events)
66{
67/// store information exclusively for rows with at least one event
68if (events.any())
69{
70events_list.emplace_back(timestamp, events);
71sorted = false;
72conditions_met |= events;
73}
74}
75
76void merge(const AggregateFunctionSequenceMatchData & other)
77{
78if (other.events_list.empty())
79return;
80
81events_list.insert(std::begin(other.events_list), std::end(other.events_list));
82sorted = false;
83conditions_met |= other.conditions_met;
84}
85
86void sort()
87{
88if (sorted)
89return;
90
91::sort(std::begin(events_list), std::end(events_list), Comparator{});
92sorted = true;
93}
94
95void serialize(WriteBuffer & buf) const
96{
97writeBinary(sorted, buf);
98writeBinary(events_list.size(), buf);
99
100for (const auto & events : events_list)
101{
102writeBinary(events.first, buf);
103writeBinary(events.second.to_ulong(), buf);
104}
105}
106
107void deserialize(ReadBuffer & buf)
108{
109readBinary(sorted, buf);
110
111size_t size;
112readBinary(size, buf);
113
114/// If we lose these flags, functionality is broken
115/// If we serialize/deserialize these flags, we have compatibility issues
116/// If we set these flags to 1, we have a minor performance penalty, which seems acceptable
117conditions_met.set();
118
119events_list.clear();
120events_list.reserve(size);
121
122for (size_t i = 0; i < size; ++i)
123{
124Timestamp timestamp;
125readBinary(timestamp, buf);
126
127UInt64 events;
128readBinary(events, buf);
129
130events_list.emplace_back(timestamp, Events{events});
131}
132}
133};
134
135
/// Upper bound on the number of steps spent matching the pattern against one sequence;
/// backtrackingMatch throws TOO_SLOW once it is exceeded.
constexpr int sequence_match_max_iterations = 1'000'000;
138
139
140template <typename T, typename Data, typename Derived>
141class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
142{
143public:
144AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_, const DataTypePtr & result_type_)
145: IAggregateFunctionDataHelper<Data, Derived>(arguments, params, result_type_)
146, pattern(pattern_)
147{
148arg_count = arguments.size();
149parsePattern();
150}
151
152void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override
153{
154const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
155
156typename Data::Events events;
157for (const auto i : collections::range(1, arg_count))
158{
159const auto event = assert_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
160events.set(i - 1, event);
161}
162
163this->data(place).add(timestamp, events);
164}
165
166void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
167{
168this->data(place).merge(this->data(rhs));
169}
170
171void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
172{
173this->data(place).serialize(buf);
174}
175
176void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
177{
178this->data(place).deserialize(buf);
179}
180
181bool haveSameStateRepresentationImpl(const IAggregateFunction & rhs) const override
182{
183return this->getName() == rhs.getName() && this->haveEqualArgumentTypes(rhs);
184}
185
186private:
187enum class PatternActionType
188{
189SpecificEvent,
190AnyEvent,
191KleeneStar,
192TimeLessOrEqual,
193TimeLess,
194TimeGreaterOrEqual,
195TimeGreater,
196TimeEqual
197};
198
199struct PatternAction final
200{
201PatternActionType type;
202std::uint64_t extra;
203
204PatternAction() = default;
205explicit PatternAction(const PatternActionType type_, const std::uint64_t extra_ = 0) : type{type_}, extra{extra_} {}
206};
207
208using PatternActions = PODArrayWithStackMemory<PatternAction, 64>;
209
210Derived & derived() { return static_cast<Derived &>(*this); }
211
212void parsePattern()
213{
214actions.clear();
215actions.emplace_back(PatternActionType::KleeneStar);
216
217dfa_states.clear();
218dfa_states.emplace_back(true);
219
220pattern_has_time = false;
221
222const char * pos = pattern.data();
223const char * begin = pos;
224const char * end = pos + pattern.size();
225
226auto throw_exception = [&](const std::string & msg)
227{
228throw Exception(ErrorCodes::SYNTAX_ERROR, "{} '{}' at position {}", msg, std::string(pos, end), toString(pos - begin));
229};
230
231auto match = [&pos, end](const char * str) mutable
232{
233size_t length = strlen(str);
234if (pos + length <= end && 0 == memcmp(pos, str, length))
235{
236pos += length;
237return true;
238}
239return false;
240};
241
242while (pos < end)
243{
244if (match("(?"))
245{
246if (match("t"))
247{
248PatternActionType type;
249
250if (match("<="))
251type = PatternActionType::TimeLessOrEqual;
252else if (match("<"))
253type = PatternActionType::TimeLess;
254else if (match(">="))
255type = PatternActionType::TimeGreaterOrEqual;
256else if (match(">"))
257type = PatternActionType::TimeGreater;
258else if (match("=="))
259type = PatternActionType::TimeEqual;
260else
261throw_exception("Unknown time condition");
262
263UInt64 duration = 0;
264const auto * prev_pos = pos;
265pos = tryReadIntText(duration, pos, end);
266if (pos == prev_pos)
267throw_exception("Could not parse number");
268
269if (actions.back().type != PatternActionType::SpecificEvent &&
270actions.back().type != PatternActionType::AnyEvent &&
271actions.back().type != PatternActionType::KleeneStar)
272throw Exception(ErrorCodes::BAD_ARGUMENTS, "Temporal condition should be preceded by an event condition");
273
274pattern_has_time = true;
275actions.emplace_back(type, duration);
276}
277else
278{
279UInt64 event_number = 0;
280const auto * prev_pos = pos;
281pos = tryReadIntText(event_number, pos, end);
282if (pos == prev_pos)
283throw_exception("Could not parse number");
284
285if (event_number > arg_count - 1)
286throw Exception(ErrorCodes::BAD_ARGUMENTS, "Event number {} is out of range", event_number);
287
288actions.emplace_back(PatternActionType::SpecificEvent, event_number - 1);
289dfa_states.back().transition = DFATransition::SpecificEvent;
290dfa_states.back().event = static_cast<uint32_t>(event_number - 1);
291dfa_states.emplace_back();
292conditions_in_pattern.set(event_number - 1);
293}
294
295if (!match(")"))
296throw_exception("Expected closing parenthesis, found");
297
298}
299else if (match(".*"))
300{
301actions.emplace_back(PatternActionType::KleeneStar);
302dfa_states.back().has_kleene = true;
303}
304else if (match("."))
305{
306actions.emplace_back(PatternActionType::AnyEvent);
307dfa_states.back().transition = DFATransition::AnyEvent;
308dfa_states.emplace_back();
309}
310else
311throw_exception("Could not parse pattern, unexpected starting symbol");
312}
313}
314
315protected:
316/// Uses a DFA based approach in order to better handle patterns without
317/// time assertions.
318///
319/// NOTE: This implementation relies on the assumption that the pattern is *small*.
320///
321/// This algorithm performs in O(mn) (with m the number of DFA states and N the number
322/// of events) with a memory consumption and memory allocations in O(m). It means that
323/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
324template <typename EventEntry>
325bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
326{
327using ActiveStates = std::vector<bool>;
328
329/// Those two vectors keep track of which states should be considered for the current
330/// event as well as the states which should be considered for the next event.
331ActiveStates active_states(dfa_states.size(), false);
332ActiveStates next_active_states(dfa_states.size(), false);
333active_states[0] = true;
334
335/// Keeps track of dead-ends in order not to iterate over all the events to realize that
336/// the match failed.
337size_t n_active = 1;
338
339for (/* empty */; events_it != events_end && n_active > 0 && !active_states.back(); ++events_it)
340{
341n_active = 0;
342next_active_states.assign(dfa_states.size(), false);
343
344for (size_t state = 0; state < dfa_states.size(); ++state)
345{
346if (!active_states[state])
347{
348continue;
349}
350
351switch (dfa_states[state].transition)
352{
353case DFATransition::None:
354break;
355case DFATransition::AnyEvent:
356next_active_states[state + 1] = true;
357++n_active;
358break;
359case DFATransition::SpecificEvent:
360if (events_it->second.test(dfa_states[state].event))
361{
362next_active_states[state + 1] = true;
363++n_active;
364}
365break;
366}
367
368if (dfa_states[state].has_kleene)
369{
370next_active_states[state] = true;
371++n_active;
372}
373}
374swap(active_states, next_active_states);
375}
376
377return active_states.back();
378}
379
380template <typename EventEntry>
381bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
382{
383const auto action_begin = std::begin(actions);
384const auto action_end = std::end(actions);
385auto action_it = action_begin;
386
387const auto events_begin = events_it;
388auto base_it = events_it;
389
390/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
391using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
392std::stack<backtrack_info> back_stack;
393
394/// backtrack if possible
395const auto do_backtrack = [&]
396{
397while (!back_stack.empty())
398{
399auto & top = back_stack.top();
400
401action_it = std::get<0>(top);
402events_it = std::next(std::get<1>(top));
403base_it = std::get<2>(top);
404
405back_stack.pop();
406
407if (events_it != events_end)
408return true;
409}
410
411return false;
412};
413
414size_t i = 0;
415while (action_it != action_end && events_it != events_end)
416{
417if (action_it->type == PatternActionType::SpecificEvent)
418{
419if (events_it->second.test(action_it->extra))
420{
421/// move to the next action and events
422base_it = events_it;
423++action_it, ++events_it;
424}
425else if (!do_backtrack())
426/// backtracking failed, bail out
427break;
428}
429else if (action_it->type == PatternActionType::AnyEvent)
430{
431base_it = events_it;
432++action_it, ++events_it;
433}
434else if (action_it->type == PatternActionType::KleeneStar)
435{
436back_stack.emplace(action_it, events_it, base_it);
437base_it = events_it;
438++action_it;
439}
440else if (action_it->type == PatternActionType::TimeLessOrEqual)
441{
442if (events_it->first <= base_it->first + action_it->extra)
443{
444/// condition satisfied, move onto next action
445back_stack.emplace(action_it, events_it, base_it);
446base_it = events_it;
447++action_it;
448}
449else if (!do_backtrack())
450break;
451}
452else if (action_it->type == PatternActionType::TimeLess)
453{
454if (events_it->first < base_it->first + action_it->extra)
455{
456back_stack.emplace(action_it, events_it, base_it);
457base_it = events_it;
458++action_it;
459}
460else if (!do_backtrack())
461break;
462}
463else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
464{
465if (events_it->first >= base_it->first + action_it->extra)
466{
467back_stack.emplace(action_it, events_it, base_it);
468base_it = events_it;
469++action_it;
470}
471else if (++events_it == events_end && !do_backtrack())
472break;
473}
474else if (action_it->type == PatternActionType::TimeGreater)
475{
476if (events_it->first > base_it->first + action_it->extra)
477{
478back_stack.emplace(action_it, events_it, base_it);
479base_it = events_it;
480++action_it;
481}
482else if (++events_it == events_end && !do_backtrack())
483break;
484}
485else if (action_it->type == PatternActionType::TimeEqual)
486{
487if (events_it->first == base_it->first + action_it->extra)
488{
489back_stack.emplace(action_it, events_it, base_it);
490base_it = events_it;
491++action_it;
492}
493else if (++events_it == events_end && !do_backtrack())
494break;
495}
496else
497throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown PatternActionType");
498
499if (++i > sequence_match_max_iterations)
500throw Exception(ErrorCodes::TOO_SLOW, "Pattern application proves too difficult, exceeding max iterations ({})",
501sequence_match_max_iterations);
502}
503
504/// if there are some actions remaining
505if (action_it != action_end)
506{
507/// match multiple empty strings at end
508while (action_it->type == PatternActionType::KleeneStar ||
509action_it->type == PatternActionType::TimeLessOrEqual ||
510action_it->type == PatternActionType::TimeLess ||
511(action_it->type == PatternActionType::TimeGreaterOrEqual && action_it->extra == 0))
512++action_it;
513}
514
515if (events_it == events_begin)
516++events_it;
517
518return action_it == action_end;
519}
520
521/// Splits the pattern into deterministic parts separated by non-deterministic fragments
522/// (time constraints and Kleene stars), and tries to match the deterministic parts in their specified order,
523/// ignoring the non-deterministic fragments.
524/// This function can quickly check that a full match is not possible if some deterministic fragment is missing.
525template <typename EventEntry>
526bool couldMatchDeterministicParts(const EventEntry events_begin, const EventEntry events_end, bool limit_iterations = true) const
527{
528size_t events_processed = 0;
529auto events_it = events_begin;
530
531const auto actions_end = std::end(actions);
532auto actions_it = std::begin(actions);
533auto det_part_begin = actions_it;
534
535auto match_deterministic_part = [&events_it, events_end, &events_processed, det_part_begin, actions_it, limit_iterations]()
536{
537auto events_it_init = events_it;
538auto det_part_it = det_part_begin;
539
540while (det_part_it != actions_it && events_it != events_end)
541{
542/// matching any event
543if (det_part_it->type == PatternActionType::AnyEvent)
544++events_it, ++det_part_it;
545
546/// matching specific event
547else
548{
549if (events_it->second.test(det_part_it->extra))
550++events_it, ++det_part_it;
551
552/// abandon current matching, try to match the deterministic fragment further in the list
553else
554{
555events_it = ++events_it_init;
556det_part_it = det_part_begin;
557}
558}
559
560if (limit_iterations && ++events_processed > sequence_match_max_iterations)
561throw Exception(ErrorCodes::TOO_SLOW, "Pattern application proves too difficult, exceeding max iterations ({})",
562sequence_match_max_iterations);
563}
564
565return det_part_it == actions_it;
566};
567
568for (; actions_it != actions_end; ++actions_it)
569if (actions_it->type != PatternActionType::SpecificEvent && actions_it->type != PatternActionType::AnyEvent)
570{
571if (!match_deterministic_part())
572return false;
573det_part_begin = std::next(actions_it);
574}
575
576return match_deterministic_part();
577}
578
579private:
580enum class DFATransition : char
581{
582/// .-------.
583/// | |
584/// `-------'
585None,
586/// .-------. (?[0-9])
587/// | | ----------
588/// `-------'
589SpecificEvent,
590/// .-------. .
591/// | | ----------
592/// `-------'
593AnyEvent,
594};
595
596struct DFAState
597{
598explicit DFAState(bool has_kleene_ = false)
599: has_kleene{has_kleene_}, event{0}, transition{DFATransition::None}
600{}
601
602/// .-------.
603/// | | - - -
604/// `-------'
605/// |_^
606bool has_kleene;
607/// In the case of a state transitions with a `SpecificEvent`,
608/// `event` contains the value of the event.
609uint32_t event;
610/// The kind of transition out of this state.
611DFATransition transition;
612};
613
614using DFAStates = std::vector<DFAState>;
615
616protected:
617/// `True` if the parsed pattern contains time assertions (?t...), `false` otherwise.
618bool pattern_has_time;
619/// sequenceMatch conditions met at least once in the pattern
620std::bitset<max_events> conditions_in_pattern;
621
622private:
623std::string pattern;
624size_t arg_count;
625PatternActions actions;
626
627DFAStates dfa_states;
628};
629
630template <typename T, typename Data>
631class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
632{
633public:
634AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_)
635: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt8>()) {}
636
637using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
638
639String getName() const override { return "sequenceMatch"; }
640
641bool allocatesMemoryInArena() const override { return false; }
642
643void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
644{
645auto & output = assert_cast<ColumnUInt8 &>(to).getData();
646if ((this->conditions_in_pattern & this->data(place).conditions_met) != this->conditions_in_pattern)
647{
648output.push_back(false);
649return;
650}
651this->data(place).sort();
652
653const auto & data_ref = this->data(place);
654
655const auto events_begin = std::begin(data_ref.events_list);
656const auto events_end = std::end(data_ref.events_list);
657auto events_it = events_begin;
658
659bool match = (this->pattern_has_time ?
660(this->couldMatchDeterministicParts(events_begin, events_end) && this->backtrackingMatch(events_it, events_end)) :
661this->dfaMatch(events_it, events_end));
662output.push_back(match);
663}
664};
665
666template <typename T, typename Data>
667class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
668{
669public:
670AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_)
671: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_, std::make_shared<DataTypeUInt64>()) {}
672
673using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
674
675String getName() const override { return "sequenceCount"; }
676
677bool allocatesMemoryInArena() const override { return false; }
678
679void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
680{
681auto & output = assert_cast<ColumnUInt64 &>(to).getData();
682if ((this->conditions_in_pattern & this->data(place).conditions_met) != this->conditions_in_pattern)
683{
684output.push_back(0);
685return;
686}
687this->data(place).sort();
688output.push_back(count(place));
689}
690
691private:
692UInt64 count(ConstAggregateDataPtr __restrict place) const
693{
694const auto & data_ref = this->data(place);
695
696const auto events_begin = std::begin(data_ref.events_list);
697const auto events_end = std::end(data_ref.events_list);
698auto events_it = events_begin;
699
700size_t count = 0;
701// check if there is a chance of matching the sequence at least once
702if (this->couldMatchDeterministicParts(events_begin, events_end))
703{
704while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
705++count;
706}
707
708return count;
709}
710};
711
712
713template <template <typename, typename> typename AggregateFunction, template <typename> typename Data>
714AggregateFunctionPtr createAggregateFunctionSequenceBase(
715const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
716{
717if (params.size() != 1)
718throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires exactly one parameter.",
719name);
720
721const auto arg_count = argument_types.size();
722
723if (arg_count < 3)
724throw Exception(ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Aggregate function {} requires at least 3 arguments.",
725name);
726
727if (arg_count - 1 > max_events)
728throw Exception(ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION, "Aggregate function {} supports up to {} event arguments.", name, max_events);
729
730const auto * time_arg = argument_types.front().get();
731
732for (const auto i : collections::range(1, arg_count))
733{
734const auto * cond_arg = argument_types[i].get();
735if (!isUInt8(cond_arg))
736throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
737"Illegal type {} of argument {} of aggregate function {}, must be UInt8",
738cond_arg->getName(), toString(i + 1), name);
739}
740
741String pattern = params.front().safeGet<std::string>();
742
743AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunction, Data>(*argument_types[0], argument_types, params, pattern));
744if (res)
745return res;
746
747WhichDataType which(argument_types.front().get());
748if (which.isDateTime())
749return std::make_shared<AggregateFunction<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types, params, pattern);
750else if (which.isDate())
751return std::make_shared<AggregateFunction<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types, params, pattern);
752
753throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
754"Illegal type {} of first argument of aggregate function {}, must be DateTime",
755time_arg->getName(), name);
756}
757
758}
759
760void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory)
761{
762factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceBase<AggregateFunctionSequenceMatch, AggregateFunctionSequenceMatchData>);
763factory.registerFunction("sequenceCount", createAggregateFunctionSequenceBase<AggregateFunctionSequenceCount, AggregateFunctionSequenceMatchData>);
764}
765
766}
767