google-research
420 строк · 13.5 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for conducting fuzzy matching for each match type, Price, Date, etc."""
17
18import datetime19import re20from typing import Any, Optional, TypeVar21
22import editdistance23
# `Entity` defines the generic type for an entity triple: (entity_text,
# entity_box, entity_segments) where:
# 1) `entity_text` is a string showing the textual contents of the entity;
# 2) `entity_box` is a tuple of 5 values describing the page index and the
#    x, y coordinates of the upper-left and lower-right corners of the
#    bounding box.
# 3) `entity_segments` is a list of two-value tuples, and each tuple indicates
#    the start and end index of this entity in the reading order text. There
#    might be multiple segments in case this entity involves multiple text
#    spans.
# e.g., ('Google Research', (0, 1.1, 2.2, 3.3, 4.4), [(0, 15)]).
#
# **The `entity_box` and `entity_segments` are used to map the entity back to
# the image or reading order sequence. They are useful when interpreting the
# behavior of the models and converting the benchmark into the required format
# for the models.
Entity = TypeVar(
    'Entity',
    bound=tuple[str, tuple[int, float, float, float, float],
                list[tuple[int, int]]])
44
class Match:
  """The ancestor class for all specific {Type}Match, e.g., DateMatch.

  Also includes the general fuzzy matching functions, e.g.
  match_by_substring(). Each {Type}Match performs fuzzy matching for its match
  type by calling these general matching functions.
  """

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """The template for any matching function of a specific {Type}Match.

    Args:
      extracted_entity: Extraction result, a tuple whose first element is the
        extracted text (the subclasses in this module only read index 0).
      labeled_entities: A list of candidate entities, each a tuple whose first
        element is the textual contents and whose remaining elements locate
        the entity in the document. Since the same entity may appear multiple
        times in the doc and the model only needs to extract one of them, a
        list of candidates is provided here. When there is only one
        appearance, the list has one element.

    Raises:
      NotImplementedError: This is just a template and should not be called.
        Instead the specific {Type}Match should be called. For example,
        DateMatch.match(('7/1/2022', box), [('07/02/2022', box)])
    """
    raise NotImplementedError

  @classmethod
  def is_entity(cls, obj):
    """Checks whether the input obj is structurally an entity triple.

    Args:
      obj: Any object.

    Returns:
      True if `obj` is a 3-tuple (str, (int, float, float, float, float),
      list of (int, int) tuples), False otherwise. The four box coordinates
      must be `float` instances; ints are rejected.
    """
    # Anything that is not a 3-tuple is not an entity. Return an explicit
    # False instead of falling through to an implicit None.
    if not isinstance(obj, tuple) or len(obj) != 3:
      return False
    text, box, segments = obj

    # 1) entity text
    if not isinstance(text, str):
      return False

    # 2) entity box: an int page index followed by four float coordinates.
    if not isinstance(box, tuple) or len(box) != 5:
      return False
    if not isinstance(box[0], int):
      return False
    if not all(isinstance(v, float) for v in box[1:]):
      return False

    # 3) entity segments: a list of (start, end) int pairs.
    if not isinstance(segments, list):
      return False
    return all(
        isinstance(segment, tuple) and len(segment) == 2 and
        isinstance(segment[0], int) and isinstance(segment[1], int)
        for segment in segments)

  @classmethod
  def remove_redundant_whitespace(cls, string):
    r"""Removes the redundant whitespace in the input string.

    1. Remove the prefix/suffix whitespace.
    2. Replace each run of whitespace with a single space.

    Args:
      string: Entity text string from extractions or annotations.

    Returns:
      The normalized string, e.g., ' abc\ndef  ghi\t' => 'abc def ghi'.
    """
    # str.split() with no argument already drops leading/trailing whitespace
    # and splits on any whitespace run, so the pieces need no extra strip().
    return ' '.join(string.split())

  @classmethod
  def match_by_alpha_numeric_text(cls, str_a, str_b):
    """Strings match if they are the same after removing all non-alpha-numeric contents.

    Only ASCII letters and digits are kept; the comparison is case-sensitive.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      If A and B are equivalent after removing all non-alpha-numeric contents.
      e.g. "Xy_Z1 2@3" == "XyZ123"
    """
    proc_a = re.sub(r'[^0-9a-zA-Z]', '', str_a)
    proc_b = re.sub(r'[^0-9a-zA-Z]', '', str_b)
    return proc_a == proc_b

  @classmethod
  def match_by_non_whitespace_text(cls, str_a, str_b):
    r"""Strings match if they are the same after removing all whitespaces.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      If A and B are equivalent after removing all whitespaces.
      e.g. "X y Z 1\t 2 3" == "XyZ123".
    """
    proc_a = re.sub(r'[\s]', '', str_a)
    proc_b = re.sub(r'[\s]', '', str_b)
    return proc_a == proc_b

  @classmethod
  def match_by_substring(cls, str_a, str_b):
    """Strings match if one is a substring of the other.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      True if, after stripping surrounding whitespace, one string is a
      substring of the other, e.g. "yZ1" == "XyZ123". Two empty strings match
      each other, but an empty string never matches a non-empty one.
    """
    proc_a = str_a.strip()
    proc_b = str_b.strip()

    if not proc_a or not proc_b:
      # '' is technically a substring of everything, so handle it explicitly:
      # only empty-vs-empty counts as a match.
      return not proc_a and not proc_b
    return proc_a in proc_b or proc_b in proc_a

  @classmethod
  def match_by_value(cls, str_a, str_b, diff = 0.01):
    """Strings match if the absolute numeric difference is at most `diff`.

    Args:
      str_a: String A
      str_b: String B
      diff: The tolerable difference (inclusive).

    Returns:
      First removes every character except digits and `.` (so currency
      symbols, thousands separators and also any minus sign are dropped), then
      A and B are equivalent if abs(A - B) <= `diff`, e.g. "$3.14" ==
      "3.1415926". Returns False if either string does not parse as a number
      after this cleanup (e.g. it is empty or contains several dots).
    """
    def to_num(text):
      return re.sub(r'[^0-9.]', '', text)

    proc_a = to_num(str_a)
    proc_b = to_num(str_b)

    try:
      num_a = float(proc_a)
      num_b = float(proc_b)
    except ValueError:
      # float() on a str only raises ValueError; anything else is a real bug
      # and should not be swallowed.
      return False

    # NOTE(review): differences of exactly `diff` are subject to float
    # rounding (e.g. 1.01 - 1.00 > 0.01), so the boundary is not reliable.
    return abs(num_a - num_b) <= diff

  @classmethod
  def match_by_numeric_text(cls, str_a, str_b):
    """Strings match if they are the same after removing all non-numeric contents.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      If A and B are equivalent after removing all non-numeric contents.
      e.g. "Xy_Z1 2@3" == "1xx2yy3zz"
    """
    proc_a = re.sub(r'[^0-9]', '', str_a)
    proc_b = re.sub(r'[^0-9]', '', str_b)
    return proc_a == proc_b

  @classmethod
  def match_by_edit_distance(cls,
                             str_a,
                             str_b,
                             threshold = 3):
    """Strings match if the edit distance is not larger than the threshold.

    The threshold is 3 by default, since we allow some small mistakes but do
    not want the mistakes to change the meaning of the content.

    Args:
      str_a: String A.
      str_b: String B.
      threshold: The maximum edit distance that can be tolerated (inclusive).

    Returns:
      If the distance is not larger than the threshold, then two strings match.
    """
    distance = editdistance.eval(str_a, str_b)
    return distance <= threshold
239
class DateMatch(Match):
  """Match class for match type `date`.

  Two dates match if they have the same year/month/day, OR they have the same
  year but the month and day are swapped (because of different date formats:
  MM/DD/YY vs DD/MM/YY). If either string cannot be decoded as a date, the two
  strings are compared with Match.match_by_alpha_numeric_text() instead.
  """

  @classmethod
  def decode_date(cls, date_string):
    """Extracts the year, month, day fields from the date string.

    Args:
      date_string: Entity text expected to contain a date.

    Returns:
      A dict {'year': int, 'month': int, 'day': int} from the first supported
      format that parses, or None if no format matches. Dates written without
      a year (e.g. '07/01') get the year 1900.
    """
    # Keep only alphanumerics and the separators used by the patterns below.
    # This also removes all whitespace, e.g. 'July 01, 2022' -> 'July01,2022'.
    proc_date_string = re.sub(r'[^0-9a-zA-Z/\-,]', '', date_string)

    # (strptime pattern, suffix appended to the string before parsing).
    # Patterns are tried in order and the first one that parses wins.
    # NOTE: %b/%B parse month names according to the current locale.
    patterns = [
        ('%m/%d/%y', ''),  # 07/01/22
        ('%m/%d/%Y', ''),  # 07/01/2022
        # 07/01 -> year 1900. Parsing a day of month without a year is
        # deprecated since Python 3.13, so the historical default year 1900
        # is appended explicitly instead of relying on strptime's default.
        ('%m/%d/%Y', '/1900'),
        ('%b%d/%y', ''),  # Jul01/22
        ('%m-%d-%Y', ''),  # 07-01-2022
        ('%m-%d-%y', ''),  # 07-01-22
        ('%B%d,%Y', ''),  # July01,2022
        ('%Y/%m/%d', ''),  # 2022/07/01
    ]

    for pattern, suffix in patterns:
      try:
        date = datetime.datetime.strptime(proc_date_string + suffix,
                                          pattern).date()
      except ValueError:
        continue
      return {'year': date.year, 'month': date.month, 'day': date.day}

    return None

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """Returns whether the extracted date matches any labeled candidate.

    Args:
      extracted_entity: Extraction result; only its text (index 0) is used.
      labeled_entities: Iterable of candidate entities; only their text
        (index 0) is used.

    Returns:
      True as soon as one candidate matches (same date, or same year with
      month and day swapped; alphanumeric comparison if either side is not a
      recognized date), otherwise False.
    """
    for labeled_entity in labeled_entities:
      extracted_text = cls.remove_redundant_whitespace(extracted_entity[0])
      labeled_text = cls.remove_redundant_whitespace(labeled_entity[0])

      extracted_date = cls.decode_date(extracted_text)
      labeled_date = cls.decode_date(labeled_text)

      if not extracted_date or not labeled_date:
        # At least one side is not a recognized date: fall back to comparing
        # the alphanumeric contents of the raw texts.
        if cls.match_by_alpha_numeric_text(extracted_text, labeled_text):
          return True
        continue

      same_year = extracted_date['year'] == labeled_date['year']
      same_month_day = (extracted_date['month'] == labeled_date['month'] and
                        extracted_date['day'] == labeled_date['day'])
      # Some date formats are ambiguous. For example, 01/02/2022 can be
      # interpreted as Feb 1, 2022 or Jan 2, 2022. It is hard to distinguish
      # without context, so a swapped day and month is also accepted.
      swapped_month_day = (extracted_date['day'] == labeled_date['month'] and
                           extracted_date['month'] == labeled_date['day'])
      if same_year and (same_month_day or swapped_month_day):
        return True

    return False
314
class PriceMatch(Match):
  """Match class for match type `price`.

  Both texts are whitespace-normalized and then compared numerically with
  Match.match_by_value() using its default tolerance.
  """

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """Returns whether the extracted price equals any labeled candidate.

    Args:
      extracted_entity: Extraction result; only its text (index 0) is used.
      labeled_entities: Iterable of candidate entities; only their text
        (index 0) is used.

    Returns:
      True as soon as one candidate matches, otherwise False (including when
      there are no candidates).
    """
    normalize = cls.remove_redundant_whitespace
    return any(
        cls.match_by_value(normalize(extracted_entity[0]),
                           normalize(candidate[0]))
        for candidate in labeled_entities)
331
class AddressMatch(Match):
  """Match class for match type `address`.

  Two strings match if their edit distance is not larger than a threshold.
  Uses Match.match_by_edit_distance() with its default threshold.
  """

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """Returns whether the extracted address is close to any labeled candidate.

    Args:
      extracted_entity: Extraction result; only its text (index 0) is used.
      labeled_entities: Iterable of candidate entities; only their text
        (index 0) is used.

    Returns:
      True as soon as one candidate is within the edit-distance threshold
      (after whitespace normalization), otherwise False.
    """
    normalize = cls.remove_redundant_whitespace
    return any(
        cls.match_by_edit_distance(normalize(extracted_entity[0]),
                                   normalize(candidate[0]))
        for candidate in labeled_entities)
349
class NumericalStringMatch(Match):
  """Match class for match type `numerical_string`.

  Two strings match if they are equivalent after removing all non-numerical
  contents. Uses Match.match_by_numeric_text().
  """

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """Returns whether the extracted digits equal those of any candidate.

    Args:
      extracted_entity: Extraction result; only its text (index 0) is used.
      labeled_entities: Iterable of candidate entities; only their text
        (index 0) is used.

    Returns:
      True as soon as one candidate has the same digit sequence, otherwise
      False.
    """
    normalize = cls.remove_redundant_whitespace
    return any(
        cls.match_by_numeric_text(normalize(extracted_entity[0]),
                                  normalize(candidate[0]))
        for candidate in labeled_entities)
367
class GeneralStringMatch(Match):
  """Match class for match type `general_string`.

  Two strings match if they are equivalent after removing all non-alpha-numeric
  contents. Uses Match.match_by_alpha_numeric_text().
  """

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """Returns whether the extracted text equals any candidate's alphanumerics.

    Args:
      extracted_entity: Extraction result; only its text (index 0) is used.
      labeled_entities: Iterable of candidate entities; only their text
        (index 0) is used.

    Returns:
      True as soon as one candidate has the same ASCII alphanumeric content
      (case-sensitive), otherwise False.
    """
    normalize = cls.remove_redundant_whitespace
    return any(
        cls.match_by_alpha_numeric_text(normalize(extracted_entity[0]),
                                        normalize(candidate[0]))
        for candidate in labeled_entities)
385
class NameMatch(Match):
  """Match class for match type `name`.

  Two strings match if they are equivalent after removing all whitespaces.
  Uses Match.match_by_non_whitespace_text().
  """

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """Returns whether the extracted name equals any candidate, ignoring spaces.

    Args:
      extracted_entity: Extraction result; only its text (index 0) is used.
      labeled_entities: Iterable of candidate entities; only their text
        (index 0) is used.

    Returns:
      True as soon as one candidate is identical once all whitespace is
      removed, otherwise False.
    """
    normalize = cls.remove_redundant_whitespace
    return any(
        cls.match_by_non_whitespace_text(normalize(extracted_entity[0]),
                                         normalize(candidate[0]))
        for candidate in labeled_entities)
403
class DefaultMatch(Match):
  """Default matching class.

  Used when no matching class is specified. Performs strict string equality
  after whitespace normalization.
  """

  @classmethod
  def match(cls, extracted_entity,
            labeled_entities):
    """Returns whether the extracted text exactly equals any candidate's text.

    Prints a notice to stdout on every call so that use of this fallback is
    visible.

    Args:
      extracted_entity: Extraction result; only its text (index 0) is used.
      labeled_entities: Iterable of candidate entities; only their text
        (index 0) is used.

    Returns:
      True as soon as one candidate's normalized text is identical, otherwise
      False.
    """
    print('The default matching function is used.')
    normalize = cls.remove_redundant_whitespace
    return any(
        normalize(extracted_entity[0]) == normalize(candidate[0])
        for candidate in labeled_entities)