# google-research
#
# Fork
# 0
# /
# match_utils.py
# 420 lines · 13.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for fuzzy matching of entities for each match type (price, date, etc.)."""

import datetime
import decimal
import re
from typing import Any, Optional, TypeVar

import editdistance

# `Entity` defines the generic type for an entity tuple: (entity_text,
# entity_box, entity_segments) where:
# 1) `entity_text` is a string showing the textual contents of the entity;
# 2) `entity_box` is a tuple of 5 values describing the page index and the
# x, y coordinates of the upper-left and lower-right corners of the bounding
# box;
# 3) `entity_segments` is a list of two-value tuples, and each tuple indicates
# the start and end index of this entity in the reading order text. There might
# be multiple segments when the entity covers multiple text spans.
# e.g., ('Google Research', (0, 1.1, 2.2, 3.3, 4.4), [(0, 15)]).
#
# The `entity_box` and `entity_segments` are used to map the entity back to
# the image or reading order sequence. They are useful when interpreting the
# behavior of the models and converting the benchmark into the required format
# for the models.
Entity = TypeVar(
    'Entity',
    bound=tuple[
        str,
        tuple[int, float, float, float, float],
        list[tuple[int, int]],
    ],
)


class Match:
  """The ancestor class for all specific {Type}Match, e.g., DateMatch.

  Also includes the general fuzzy matching functions, e.g.,
  match_by_substring(). Each {Type}Match performs fuzzy matching for its match
  type by calling these general matching functions.
  """

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """The template for any matching function of a specific {Type}Match.

    Args:
      extracted_entity: Extraction result, an entity tuple whose first element
        is the entity text, e.g. (text, bbox, segments). The matchers defined
        in this module only read the text at index 0.
      labeled_entities: A list of candidate entities: [(text, bbox, segments),
        ...], where `text` indicates the textual contents and `bbox` locates
        the entity uniquely in the page. Since the same entity may appear
        multiple times in the doc and the model only needs to extract one of
        them, a list of candidates is provided here. When there is only one
        appearance, the list has one element.

    Raises:
      NotImplementedError: This is just a template and should not be called.
        Instead the specific {Type}Match should be called. For example,
        DateMatch.match(('7/1/2022', box), [('07/02/2022', box)])
    """
    raise NotImplementedError

  @classmethod
  def is_entity(cls, obj):
    """Checks whether `obj` has the structure of an entity.

    An entity is a 3-tuple (text, box, segments): `text` is a str, `box` is a
    5-tuple of an int page index followed by four float coordinates, and
    `segments` is a list of (int, int) tuples.

    Args:
      obj: The object to check.

    Returns:
      True if `obj` has the entity structure, False otherwise (including for
      any object that is not a 3-tuple).
    """
    if not isinstance(obj, tuple) or len(obj) != 3:
      return False
    text, box, segments = obj
    # 1) Entity text.
    if not isinstance(text, str):
      return False
    # 2) Entity box: page index followed by the corner coordinates.
    if not isinstance(box, tuple) or len(box) != 5:
      return False
    if not isinstance(box[0], int):
      return False
    if not all(isinstance(v, float) for v in box[1:]):
      return False
    # 3) Entity segments: (start, end) indices into the reading-order text.
    if not isinstance(segments, list):
      return False
    return all(
        isinstance(segment, tuple) and len(segment) == 2 and
        isinstance(segment[0], int) and isinstance(segment[1], int)
        for segment in segments)

  @classmethod
  def remove_redundant_whitespace(cls, string):
    r"""Removes the redundant whitespace in the input string.

    1. Remove the prefix/suffix whitespace.
    2. Replace the continuous whitespace with a single space.

    Args:
      string: Entity text string from extractions or annotations.

    Returns:
      The string without prefix/suffix whitespace and with every whitespace run
      replaced by a single space, e.g., ' abc\ndef   ghi\t' => 'abc def ghi'.
    """
    # str.split() with no separator already drops leading/trailing whitespace
    # and splits on runs of whitespace.
    return ' '.join(string.split())

  @classmethod
  def match_by_alpha_numeric_text(cls, str_a, str_b):
    """Strings match if they are the same after removing all non-alpha-numeric contents.

    Only ASCII letters and digits are kept; the comparison is case-sensitive.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      If A and B are equivalent after removing all non-alpha-numeric contents.
      e.g. "Xy_Z1 2@3" == "XyZ123"
    """
    proc_a = re.sub(r'[^0-9a-zA-Z]', '', str_a)
    proc_b = re.sub(r'[^0-9a-zA-Z]', '', str_b)
    return proc_a == proc_b

  @classmethod
  def match_by_non_whitespace_text(cls, str_a, str_b):
    r"""Strings match if they are the same after removing all whitespaces.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      If A and B are equivalent after removing all whitespaces.
      e.g. "X y Z 1\t 2 3" == "XyZ123".
    """
    proc_a = re.sub(r'\s', '', str_a)
    proc_b = re.sub(r'\s', '', str_b)
    return proc_a == proc_b

  @classmethod
  def match_by_substring(cls, str_a, str_b):
    """Strings match if one is a substring of the other.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      A and B are equivalent if, after stripping surrounding whitespace, one is
      a substring of the other, e.g. "yZ1" == "XyZ123". Two empty strings
      match, but an empty string never matches a non-empty one.
    """
    proc_a = str_a.strip()
    proc_b = str_b.strip()
    if not proc_a or not proc_b:
      # '' is technically a substring of everything; only accept it when both
      # sides are empty.
      return proc_a == proc_b
    return proc_a in proc_b or proc_b in proc_a

  @classmethod
  def match_by_value(cls, str_a, str_b, diff=0.01):
    """Strings match if their numeric values differ by at most `diff`.

    Args:
      str_a: String A
      str_b: String B
      diff: The tolerable difference (inclusive).

    Returns:
      First remove every character other than digits and `.` (so signs and
      commas are dropped, e.g. "-1,234" becomes "1234"). Then A and B are
      equivalent if the absolute difference of their values is not larger than
      `diff`, e.g. "$3.14" == "3.1415926". Returns False if either string does
      not contain a parsable number (e.g. "" or "1.2.3").
    """
    proc_a = re.sub(r'[^0-9.]', '', str_a)
    proc_b = re.sub(r'[^0-9.]', '', str_b)
    try:
      # Compare in decimal rather than binary floating point so a difference
      # of exactly `diff` is accepted regardless of magnitude; with floats,
      # abs(1.01 - 1.00) is slightly larger than 0.01.
      num_a = decimal.Decimal(proc_a)
      num_b = decimal.Decimal(proc_b)
      return abs(num_a - num_b) <= decimal.Decimal(str(diff))
    except decimal.InvalidOperation:
      # Not a number after stripping (or an unusable `diff`): no match.
      return False

  @classmethod
  def match_by_numeric_text(cls, str_a, str_b):
    """Strings match if they are the same after removing all non-numeric contents.

    Args:
      str_a: String A
      str_b: String B

    Returns:
      If A and B are equivalent after removing all non-numeric contents.
      e.g. "Xy_Z1 2@3" == "1xx2yy3zz"
    """
    proc_a = re.sub(r'[^0-9]', '', str_a)
    proc_b = re.sub(r'[^0-9]', '', str_b)
    return proc_a == proc_b

  @classmethod
  def match_by_edit_distance(cls, str_a, str_b, threshold=3):
    """Strings match if the edit distance is not larger than the threshold.

    The threshold is 3 by default, since we allow some small mistakes but do
    not want the mistakes to change the content meaning.

    Args:
      str_a: String A.
      str_b: String B.
      threshold: The maximum edit distance that is tolerated (inclusive).

    Returns:
      True if the edit distance between the strings is not larger than the
      threshold, otherwise False.
    """
    distance = editdistance.eval(str_a, str_b)
    return distance <= threshold


class DateMatch(Match):
  """Match class for match type `date`.

  Two dates match if they have the same year/month/day, OR they have the same
  year but the month and day are swapped (because of different date formats:
  MM/DD/YY vs DD/MM/YY). If either text cannot be decoded as a date, the two
  texts are compared with Match.match_by_alpha_numeric_text() instead.
  """

  # Formats tried in order by decode_date(); the first one that parses the
  # whole (cleaned) string wins. Examples show the text after cleaning, which
  # removes spaces, so 'July 01, 2022' is seen as 'July01,2022'.
  _DATE_PATTERNS = (
      '%m/%d/%y',  # 07/01/22
      '%m/%d/%Y',  # 07/01/2022
      # 07/01, year will be set as 1900 by default. NOTE: 1900 is not a leap
      # year, so '02/29' fails to decode with this pattern.
      '%m/%d',
      '%b%d/%y',  # Jul01/22
      '%m-%d-%Y',  # 07-01-2022
      '%m-%d-%y',  # 07-01-22
      '%B%d,%Y',  # July01,2022
      '%Y/%m/%d',  # 2022/07/01
  )

  @classmethod
  def decode_date(cls, date_string):
    """Extracts the year, month, day fields from the date string.

    All characters other than ASCII letters, digits, '/', '-' and ',' are
    removed first, then each pattern in `_DATE_PATTERNS` is tried in order.

    Args:
      date_string: The date text.

    Returns:
      A dict {'year': int, 'month': int, 'day': int} from the first pattern
      that parses the cleaned string, or None if no pattern matches.
    """
    proc_date_string = re.sub(r'[^0-9a-zA-Z/\-,]', '', date_string)
    for pattern in cls._DATE_PATTERNS:
      try:
        date = datetime.datetime.strptime(proc_date_string, pattern).date()
      except ValueError:
        continue
      return {'year': date.year, 'month': date.month, 'day': date.day}
    return None

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """Returns True if the extracted date matches any labeled date.

    Args:
      extracted_entity: Extracted entity; only its text (index 0) is used.
      labeled_entities: Candidate labeled entities; only their texts are used.

    Returns:
      True as soon as one candidate matches (see the class docstring for the
      rules), otherwise False.
    """
    # The extracted side is the same for every candidate: normalize and decode
    # it once instead of once per candidate.
    extracted_text = cls.remove_redundant_whitespace(extracted_entity[0])
    extracted_date = cls.decode_date(extracted_text)

    for labeled_entity in labeled_entities:
      labeled_text = cls.remove_redundant_whitespace(labeled_entity[0])
      # The labeled date is only needed when the extracted text is a date.
      labeled_date = cls.decode_date(labeled_text) if extracted_date else None

      if not extracted_date or not labeled_date:
        # At least one side is not a recognizable date: compare as text.
        if cls.match_by_alpha_numeric_text(extracted_text, labeled_text):
          return True
        continue

      if extracted_date['year'] != labeled_date['year']:
        continue
      same_order = (extracted_date['month'] == labeled_date['month'] and
                    extracted_date['day'] == labeled_date['day'])
      # Some date formats are ambiguous. For example, 01/02/2022 can be
      # interpreted as Feb 1, 2022 or Jan 2, 2022, which is hard to
      # distinguish without context. Therefore, if the day and month are
      # swapped, we also consider the result as correct.
      swapped = (extracted_date['day'] == labeled_date['month'] and
                 extracted_date['month'] == labeled_date['day'])
      if same_order or swapped:
        return True

    return False


class PriceMatch(Match):
  """Match class for match type `price`.

  Both texts are whitespace-normalized and then compared numerically with
  Match.match_by_value() using its default tolerance.
  """

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """Returns True if any labeled price matches the extracted price by value.

    Args:
      extracted_entity: Extracted entity; only its text (index 0) is used.
      labeled_entities: Candidate labeled entities; only their texts are used.

    Returns:
      True as soon as one candidate matches, otherwise False.
    """
    return any(
        cls.match_by_value(
            cls.remove_redundant_whitespace(extracted_entity[0]),
            cls.remove_redundant_whitespace(candidate[0]))
        for candidate in labeled_entities)


class AddressMatch(Match):
  """Match class for match type `address`.

  Two whitespace-normalized addresses match if their edit distance is within
  the default threshold of Match.match_by_edit_distance().
  """

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """Returns True if any labeled address is close enough in edit distance.

    Args:
      extracted_entity: Extracted entity; only its text (index 0) is used.
      labeled_entities: Candidate labeled entities; only their texts are used.

    Returns:
      True as soon as one candidate matches, otherwise False.
    """
    return any(
        cls.match_by_edit_distance(
            cls.remove_redundant_whitespace(extracted_entity[0]),
            cls.remove_redundant_whitespace(candidate[0]))
        for candidate in labeled_entities)


class NumericalStringMatch(Match):
  """Match class for match type `numerical_string`.

  Two texts match when their digit sequences (0-9, everything else removed)
  are identical; see Match.match_by_numeric_text().
  """

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """Returns True if any labeled text has the same digits as the extraction.

    Args:
      extracted_entity: Extracted entity; only its text (index 0) is used.
      labeled_entities: Candidate labeled entities; only their texts are used.

    Returns:
      True as soon as one candidate matches, otherwise False.
    """
    return any(
        cls.match_by_numeric_text(
            cls.remove_redundant_whitespace(extracted_entity[0]),
            cls.remove_redundant_whitespace(candidate[0]))
        for candidate in labeled_entities)


class GeneralStringMatch(Match):
  """Match class for match type `general_string`.

  Two texts match when they are identical after every character other than
  ASCII letters and digits is removed; see
  Match.match_by_alpha_numeric_text().
  """

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """Returns True if any labeled text equals the extraction alpha-numerically.

    Args:
      extracted_entity: Extracted entity; only its text (index 0) is used.
      labeled_entities: Candidate labeled entities; only their texts are used.

    Returns:
      True as soon as one candidate matches, otherwise False.
    """
    return any(
        cls.match_by_alpha_numeric_text(
            cls.remove_redundant_whitespace(extracted_entity[0]),
            cls.remove_redundant_whitespace(candidate[0]))
        for candidate in labeled_entities)


class NameMatch(Match):
  """Match class for match type `name`.

  Two texts match when they are identical (case-sensitive) once all
  whitespace is removed; see Match.match_by_non_whitespace_text().
  """

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """Returns True if any labeled name equals the extraction ignoring spaces.

    Args:
      extracted_entity: Extracted entity; only its text (index 0) is used.
      labeled_entities: Candidate labeled entities; only their texts are used.

    Returns:
      True as soon as one candidate matches, otherwise False.
    """
    return any(
        cls.match_by_non_whitespace_text(
            cls.remove_redundant_whitespace(extracted_entity[0]),
            cls.remove_redundant_whitespace(candidate[0]))
        for candidate in labeled_entities)


class DefaultMatch(Match):
  """Fallback matching class, used when no match type is specified.

  Performs strict equality on the whitespace-normalized texts.
  """

  @classmethod
  def match(cls, extracted_entity, labeled_entities):
    """Returns True if any labeled text exactly equals the extracted text.

    Prints a notice to stdout once per call, before any comparison.

    Args:
      extracted_entity: Extracted entity; only its text (index 0) is used.
      labeled_entities: Candidate labeled entities; only their texts are used.

    Returns:
      True as soon as one candidate matches, otherwise False.
    """
    print('The default matching function is used.')
    return any(
        cls.remove_redundant_whitespace(extracted_entity[0]) ==
        cls.remove_redundant_whitespace(candidate[0])
        for candidate in labeled_entities)

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to JSC "SberTech" processing your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.