# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to compute simple region-based lexical accuracy metric."""

import collections
import csv
import dataclasses
import enum
import pathlib
import re
from typing import IO, Iterable, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
from immutabledict import immutabledict


# The following maps define the terms of interest, indexed on the original
# English seed term.

# Format: English: (Simp-CN, Simp-TW, Trad-TW, Trad-CN)
# Spaces are stripped from the Chinese corpus before matching these terms.
_CHINESE_TERMS = immutabledict({
    "Pineapple": ("菠萝", "凤梨", "鳳梨", "菠蘿"),
    "Computer mouse": ("鼠标", "滑鼠", "滑鼠", "鼠標"),
    # Original source had CN:牛油果, but translator used 鳄梨.
    "Avocado": ("鳄梨", "酪梨", "酪梨", "鱷梨"),
    "Band-Aid": ("创可贴", "OK绷", "OK繃", "創可貼"),
    "Blog": ("博客", "部落格", "部落格", "博客"),
    "New Zealand": ("新西兰", "纽西兰", "紐西蘭", "新西蘭"),
    "Printer (computing)": ("打印机", "印表机", "印表機", "打印機"),
    # Original source had TW:月臺, but translator used 月台.
    "Railway platform": ("站台", "月台", "月台", "站台"),
    "Roller coaster": ("过山车", "云霄飞车", "雲霄飛車", "過山車"),
    "Salmon": ("三文鱼", "鲑鱼", "鮭魚", "三文魚"),
    "Shampoo": ("洗发水", "洗发精", "洗髮精", "洗髮水"),
    # From Wikipedia page "Software testing"
    "Software": ("软件", "软体", "軟體", "軟件"),
    "Sydney": ("悉尼", "雪梨", "雪梨", "悉尼"),

    # The following two are excluded because they underpin the first 100
    # lexical exemplars used for priming the models.
    ## "Flip-flops": ("人字拖", "夹脚拖", "夾腳拖", "人字拖"),
    ## "Paper clip": ("回形针", "回纹针", "迴紋針", "回形針"),
})

# Portuguese terms.
# Format: English: (BR, PT)
# The Portuguese corpus is lowercased before matching these terms.
_PORTUGUESE_TERMS = immutabledict({
    "Bathroom": ("banheiro", "casa de banho"),
    # Original source had "pequeno almoço" but translator used "pequeno-almoço".
    "Breakfast": ("café da manhã", "pequeno-almoço"),
    "Bus": ("ônibus", "autocarro"),
    "Cup": ("xícara", "chávena"),
    "Computer mouse": ("mouse", "rato"),
    "Drivers license": ("carteira de motorista", "carta de condução"),
    # From Wikipedia page "Ice cream sandwich"
    "Ice cream": ("sorvete", "gelado"),
    "Juice": ("suco", "sumo"),
    "Mobile phone": ("celular", "telemóvel"),
    "Pedestrian": ("pedestre", "peão"),
    # From Wikipedia page "Pickpocketing"
    "Pickpocket": ("batedor de carteiras", "carteirista"),
    "Pineapple": ("abacaxi", "ananás"),
    "Refrigerator": ("geladeira", "frigorífico"),
    "Suit": ("terno", "fato"),
    "Train": ("trem", "comboio"),
    "Video game": ("videogame", "videojogos"),

    # Terms updated after original selection.

    # For BR, replaced "menina" (common in speech) with "garota" (common in
    # writing), matching the human translators.
    "Girl": ("garota", "rapariga"),

    # Replaced original "Computer monitor": ("tela de computador", "ecrã") with
    # the observed use for just "screen":
    "Screen": ("tela", "ecrã"),

    # Terms excluded.

    # The following three are excluded because they underpin the first 100
    # lexical exemplars used for priming the models.
    ## "Gym": ("academia", "ginásio"),
    ## "Stapler": ("grampeador", "agrafador"),
    ## "Nightgown": ("camisola", "camisa de noite"),

    # The following are excluded for other reasons:

    # The BR translator primarily used 'comissário de bordo' and hardly ever
    # 'aeromoça'. The PT translator used 'comissários/assistentes de bordo' or
    # just 'assistentes de bordo'. Excluding the term as low-signal for now.
    ## "Flight attendant": ("aeromoça", "comissário ao bordo"),

    # Both regions' translators consistently used "presunto", so the term has
    # low signal.
    ## "Ham": ("presunto", "fiambre"),
})

_StrOrPurePath = Union[str, pathlib.PurePath]


def _open_file(path: _StrOrPurePath, mode: str = "r") -> IO[str]:
  return open(path, mode)  # pylint: disable=unreachable


@dataclasses.dataclass(frozen=True)
class TermCount:
  matched: int
  mismatched: int


TermCounts = Mapping[str, TermCount]


class ZhScript(enum.Enum):
  SIMPLIFIED = 1
  TRADITIONAL = 2


def _count_term_hits_with_regex(text: str, term: str) -> int:
  # Avoids overtriggering when term happens to be a substring in unrelated
  # words.
  pattern = r"\b" + term + r"\b"
  return len(re.findall(pattern, text))
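

# Illustrative sketch, not part of the original library: a hypothetical helper
# showing why word-boundary matching matters for spaced languages. "rato"
# inside the unrelated word "barato" is ignored by the regex counter but would
# be counted by a plain substring search. The example strings are invented.
def _example_word_boundary_counting() -> None:
  assert _count_term_hits_with_regex("o rato roeu a roupa", "rato") == 1
  # No hit: "rato" only occurs inside "barato".
  assert _count_term_hits_with_regex("ficou barato demais", "rato") == 0
  # A bare substring count would overtrigger on the same sentence.
  assert "ficou barato demais".count("rato") == 1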


def _score_terms(
    corpus: Sequence[str],
    matched_terms: Iterable[str],
    mismatched_terms: Iterable[str],
    per_example_cap: int = 1,
    use_regex: bool = True,
) -> TermCount:
  """Scores term by counting number of non-overlapping substring occurrences."""
  matched_total = 0
  mismatched_total = 0

  def _count(sentence: str, term: str) -> int:
    if use_regex:
      return _count_term_hits_with_regex(sentence, term)
    else:
      return sentence.count(term)

  for sentence in corpus:
    matched_term_counts = [
        _count(sentence, matched_term) for matched_term in matched_terms
    ]
    matched_count = min(sum(matched_term_counts), per_example_cap)
    matched_total += matched_count
    mismatched_term_counts = [
        _count(sentence, mismatched_term)
        for mismatched_term in mismatched_terms
    ]
    mismatched_count = min(sum(mismatched_term_counts), per_example_cap)
    mismatched_total += mismatched_count

    for matched_term, matched_term_count in zip(matched_terms,
                                                matched_term_counts):
      if matched_term_count > 0:
        logging.debug("Hit (match) '%s': %s", matched_term, sentence)
    for mismatched_term, mismatched_term_count in zip(mismatched_terms,
                                                      mismatched_term_counts):
      if mismatched_term_count > 0:
        logging.debug("Hit (mismatch) '%s': %s", mismatched_term, sentence)

  return TermCount(matched=matched_total, mismatched=mismatched_total)


def score_corpus(
    corpus: Iterable[str],
    terms: Mapping[str, Tuple[Iterable[str], Iterable[str]]],
    use_regex: bool = True,
) -> TermCounts:
  r"""Counts occurrences of matching and non-matching terms in corpus.

  Args:
    corpus: Text to evaluate, in the target language, as an iterable over e.g.
      sentences.
    terms: Map from a source language term to its target language translations
      (x, y) such that x is matched to the regional variant of the corpus, and y
      is mis-matched to the regional variant of the corpus. x and y are
      iterables that may contain alternative orthographic realizations of the
      term variant.
    use_regex: Whether to do term matching using a regex that requires word
      boundaries (\b) around the search term, otherwise uses str.count(). Set to
      False if the corpus is in a non-spaced language like Chinese.

  Returns:
    Map from each source language term to a TermCount, recording the number
    of occurrences of the matched and mismatched terms in the corpus.
  """
  corpus = list(corpus)
  return {
      source_word:
      _score_terms(corpus, matching, mismatching, use_regex=use_regex)
      for source_word, (matching, mismatching) in terms.items()
  }
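

# Hypothetical usage sketch, not part of the original library: scores a toy
# two-sentence pt-BR corpus against hand-picked (matched, mismatched) variant
# lists. The sentences, term selection, and expected counts in the comments
# are invented for illustration only.
def _example_score_corpus() -> TermCounts:
  corpus = [
      "peguei o ônibus e tomei um suco",
      "o autocarro estava atrasado",
  ]
  terms = {
      # For a pt-BR corpus, the BR variant is the "matched" side.
      "Bus": (["ônibus"], ["autocarro"]),
      "Juice": (["suco"], ["sumo"]),
  }
  # Expected: {"Bus": TermCount(matched=1, mismatched=1),
  #            "Juice": TermCount(matched=1, mismatched=0)}
  return score_corpus(corpus, terms)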


def score_pt(corpus_br: Iterable[str],
             corpus_pt: Iterable[str]) -> Tuple[TermCounts, TermCounts]:
  """Calls score_corpus using the hardcoded list of Portuguese terms."""
  # _PORTUGUESE_TERMS is already organized as (match, mismatch)-pairs for pt-BR,
  # but needs to be converted from strings to lists of strings.
  counts_br = score_corpus(
      corpus_br,
      terms={
          word: ([br], [pt]) for word, (br, pt) in _PORTUGUESE_TERMS.items()
      })

  # _PORTUGUESE_TERMS must be reorganized as (match, mismatch)-pairs for pt-PT.
  counts_pt = score_corpus(
      corpus_pt,
      terms={
          word: ([pt], [br]) for word, (br, pt) in _PORTUGUESE_TERMS.items()
      })

  return counts_br, counts_pt


def score_zh(corpus_cn: Iterable[str], corpus_tw: Iterable[str],
             script_cn: Optional[ZhScript],
             script_tw: Optional[ZhScript]) -> Tuple[TermCounts, TermCounts]:
  """Calls score_corpus using the hardcoded list of Chinese terms."""
  # Reformat the Chinese term dictionary into (match, mismatch)-pairs for each
  # corpus, based on the script of the corpus.
  terms_for_cn = {}
  for word, (simp_cn, simp_tw, trad_tw, trad_cn) in _CHINESE_TERMS.items():
    if script_cn == ZhScript.SIMPLIFIED:
      terms_for_cn[word] = ([simp_cn], [simp_tw])
    elif script_cn == ZhScript.TRADITIONAL:
      terms_for_cn[word] = ([trad_cn], [trad_tw])
    else:
      terms_for_cn[word] = ([simp_cn, trad_cn], [simp_tw, trad_tw])
  counts_cn = score_corpus(corpus_cn, terms=terms_for_cn, use_regex=False)

  terms_for_tw = {}
  for word, (simp_cn, simp_tw, trad_tw, trad_cn) in _CHINESE_TERMS.items():
    if script_tw == ZhScript.SIMPLIFIED:
      terms_for_tw[word] = ([simp_tw], [simp_cn])
    elif script_tw == ZhScript.TRADITIONAL:
      terms_for_tw[word] = ([trad_tw], [trad_cn])
    else:
      terms_for_tw[word] = ([simp_tw, trad_tw], [simp_cn, trad_cn])
  counts_tw = score_corpus(corpus_tw, terms=terms_for_tw, use_regex=False)

  return counts_cn, counts_tw


def compute_summary(results: Sequence[TermCounts]) -> float:
  """Returns the matched-fraction when summing over the results."""
  tally_matched = 0
  tally_mismatched = 0
  for corpus_counts in results:
    for term_pair in corpus_counts.values():
      tally_matched += term_pair.matched
      tally_mismatched += term_pair.mismatched

  # Set to zero if there were no hits.
  if tally_matched + tally_mismatched == 0:
    return 0.0

  return tally_matched / (tally_matched + tally_mismatched)
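

# Hypothetical worked example, not from the original library: with 7 matched
# and 3 mismatched hits tallied across two corpora, the summary metric is
# 7 / (7 + 3) = 0.7. The per-term counts below are invented.
def _example_compute_summary() -> float:
  results = [
      {"Bus": TermCount(matched=4, mismatched=1)},
      {"Juice": TermCount(matched=3, mismatched=2)},
  ]
  # Expected: (4 + 3) / (4 + 3 + 1 + 2) = 0.7
  return compute_summary(results)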


def _to_csv(corpus_results: Sequence[TermCounts], lang_codes: Sequence[str],
            path: str) -> None:
  """Outputs results to CSV, assuming parallel corpus_results & lang_codes."""
  assert len(corpus_results) == len(lang_codes), (corpus_results, lang_codes)
  fieldnames = ["source_word"]
  for lang_code in lang_codes:
    fieldnames.append(f"corpus_{lang_code}_matched")
    fieldnames.append(f"corpus_{lang_code}_mismatched")

  # Create a single row for each source language term.
  rows = collections.defaultdict(dict)
  for corpus_result, lang_code in zip(corpus_results, lang_codes):
    for source_word, term_pair in corpus_result.items():
      rows[source_word]["source_word"] = source_word
      rows[source_word][f"corpus_{lang_code}_matched"] = term_pair.matched
      rows[source_word][f"corpus_{lang_code}_mismatched"] = term_pair.mismatched

  with _open_file(path, "w") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows.values())
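

# Illustrative note (added, not in the original source): for lang_codes
# ("br", "pt") the CSV produced by _to_csv has one row per source term with
# the columns
#   source_word, corpus_br_matched, corpus_br_mismatched,
#   corpus_pt_matched, corpus_pt_mismatched
# so a hypothetical row "Bus,1,0,0,1" would mean one matched hit in the BR
# corpus and one mismatched hit in the PT corpus. The row values are invented.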


def _maybe_read_lines(file_path: Optional[str]) -> Optional[Sequence[str]]:
  if file_path is None:
    return None
  with _open_file(file_path) as file:
    return file.readlines()


def _output(summary_metric: float, term_counts: Tuple[TermCounts, TermCounts],
            lang_codes: Tuple[str, str], output_path: Optional[str]) -> None:
  if output_path is not None:
    _to_csv(term_counts, lang_codes, f"{output_path}_terms.csv")
    with _open_file(f"{output_path}_lex_acc.txt", mode="wt") as out:
      print(summary_metric, file=out)


def run_pt_eval(
    corpus_br: Sequence[str],
    corpus_pt: Sequence[str],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Portuguese.

  Includes lowercasing the input corpora.

  Args:
    corpus_br: List of BR-targeted translations, parallel to corpus_pt.
    corpus_pt: List of PT-targeted translations, parallel to corpus_br.

  Returns:
    - summary metric
    - TermCounts for BR-corpus and PT-corpus, in that order.
  """

  # Lowercase the Portuguese inputs.
  corpus_br = [line.strip().lower() for line in corpus_br]
  corpus_pt = [line.strip().lower() for line in corpus_pt]

  assert len(corpus_br) == len(corpus_pt), (
      f"{len(corpus_br)} != {len(corpus_pt)}")
  term_counts = score_pt(corpus_br=corpus_br, corpus_pt=corpus_pt)
  summary_metric = compute_summary(term_counts)
  return summary_metric, term_counts
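

# Hypothetical end-to-end sketch, not part of the original library: runs the
# Portuguese evaluation on a tiny invented pair of parallel corpora. The BR
# corpus uses "ônibus"/"suco" and the PT corpus uses "autocarro"/"sumo", so
# every hit is a matched hit and the expected summary metric is 1.0.
def _example_run_pt_eval() -> float:
  corpus_br = ["Peguei o ônibus.", "Bebi um suco de laranja."]
  corpus_pt = ["Apanhei o autocarro.", "Bebi um sumo de laranja."]
  summary_metric, _ = run_pt_eval(corpus_br=corpus_br, corpus_pt=corpus_pt)
  return summary_metric  # Expected: 1.0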


def run_pt_eval_from_files(
    corpus_br_path: str,
    corpus_pt_path: str,
    output_path: Optional[str],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Portuguese from files."""
  with _open_file(corpus_br_path) as file:
    corpus_br = file.readlines()
  with _open_file(corpus_pt_path) as file:
    corpus_pt = file.readlines()
  logging.info("Read %d BR entries from %s", len(corpus_br), corpus_br_path)
  logging.info("Read %d PT entries from %s", len(corpus_pt), corpus_pt_path)

  summary_metric, term_counts = run_pt_eval(
      corpus_br=corpus_br, corpus_pt=corpus_pt)

  # The literal language codes below follow the order of the corpus pair in
  # term_counts.
  _output(summary_metric, term_counts, ("br", "pt"), output_path)
  return summary_metric, term_counts


def run_zh_eval(
    corpus_cn: Sequence[str],
    corpus_tw: Sequence[str],
    script_cn: Optional[ZhScript],
    script_tw: Optional[ZhScript],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Chinese.

  Includes normalizing away spaces in the input corpora.

  Args:
    corpus_cn: List of CN-targeted translations, parallel to corpus_tw.
    corpus_tw: List of TW-targeted translations, parallel to corpus_cn.
    script_cn: The Chinese script to expect for corpus_cn, determining which
      script's term to use for matching. If None, matches against both scripts.
    script_tw: The Chinese script to expect for corpus_tw, determining which
      script's term to use for matching. If None, matches against both scripts.

  Returns:
    - summary metric
    - TermCounts for CN-corpus and TW-corpus, in that order.
  """
  # Normalize away all spaces, which translators sometimes include in
  # mixed-script words.
  corpus_cn = [line.strip().replace(" ", "") for line in corpus_cn]
  corpus_tw = [line.strip().replace(" ", "") for line in corpus_tw]

  assert len(corpus_cn) == len(corpus_tw), (
      f"{len(corpus_cn)} != {len(corpus_tw)}")
  term_counts = score_zh(
      corpus_cn=corpus_cn,
      corpus_tw=corpus_tw,
      script_cn=script_cn,
      script_tw=script_tw)
  summary_metric = compute_summary(term_counts)
  return summary_metric, term_counts


def run_zh_eval_from_files(
    corpus_cn_path: str,
    corpus_tw_path: str,
    script_cn: Optional[ZhScript],
    script_tw: Optional[ZhScript],
    output_path: Optional[str],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Chinese using file paths."""
  with _open_file(corpus_cn_path) as file:
    corpus_cn = file.readlines()
  with _open_file(corpus_tw_path) as file:
    corpus_tw = file.readlines()
  logging.info("Read %d CN entries from %s", len(corpus_cn), corpus_cn_path)
  logging.info("Read %d TW entries from %s", len(corpus_tw), corpus_tw_path)

  summary_metric, term_counts = run_zh_eval(
      corpus_cn=corpus_cn,
      corpus_tw=corpus_tw,
      script_cn=script_cn,
      script_tw=script_tw)

  # The literal language codes below follow the order of the corpus pair in
  # term_counts.
  _output(summary_metric, term_counts, ("cn", "tw"), output_path)
  return summary_metric, term_counts
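

# Hypothetical invocation sketch, not part of the original library: shows how
# the file-based Chinese evaluation might be wired up. The file paths and the
# output prefix are invented placeholders; when an output path is given,
# run_zh_eval_from_files writes "<output_prefix>_terms.csv" and
# "<output_prefix>_lex_acc.txt" alongside returning the metric.
def _example_run_zh_eval_from_files() -> float:
  summary_metric, _ = run_zh_eval_from_files(
      corpus_cn_path="/tmp/translations.zh_cn.txt",  # hypothetical path
      corpus_tw_path="/tmp/translations.zh_tw.txt",  # hypothetical path
      script_cn=ZhScript.SIMPLIFIED,
      script_tw=ZhScript.TRADITIONAL,
      output_path="/tmp/zh_eval",  # hypothetical output prefix
  )
  return summary_metric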