# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to compute simple region-based lexical accuracy metric."""
import collections
import csv
import dataclasses
import enum
import pathlib
import re
from typing import IO, Iterable, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
from immutabledict import immutabledict

# The following maps define the terms of interest, indexed on the original
# English seed term.

# Format: English: (Simp-CN, Simp-TW, Trad-TW, Trad-CN)
# Spaces are stripped from the Chinese corpus before matching these terms.
_CHINESE_TERMS = immutabledict({
    "Pineapple": ("菠萝", "凤梨", "鳳梨", "菠蘿"),
    "Computer mouse": ("鼠标", "滑鼠", "滑鼠", "鼠標"),
    # Original source had CN:牛油果, but translator used 鳄梨.
    "Avocado": ("鳄梨", "酪梨", "酪梨", "鱷梨"),
    "Band-Aid": ("创可贴", "OK绷", "OK繃", "創可貼"),
    "Blog": ("博客", "部落格", "部落格", "博客"),
    "New Zealand": ("新西兰", "纽西兰", "紐西蘭", "新西蘭"),
    "Printer (computing)": ("打印机", "印表机", "印表機", "打印機"),
    # Original source had TW:月臺, but translator used 月台.
    "Railway platform": ("站台", "月台", "月台", "站台"),
    "Roller coaster": ("过山车", "云霄飞车", "雲霄飛車", "過山車"),
    "Salmon": ("三文鱼", "鲑鱼", "鮭魚", "三文魚"),
    "Shampoo": ("洗发水", "洗发精", "洗髮精", "洗髮水"),
    # From Wikipedia page "Software testing"
    "Software": ("软件", "软体", "軟體", "軟件"),
    "Sydney": ("悉尼", "雪梨", "雪梨", "悉尼"),

    # The following two are excluded because they underpin the first 100
    # lexical exemplars used for priming the models.
    ## "Flip-flops": ("人字拖", "夹脚拖", "夾腳拖", "人字拖"),
    ## "Paper clip": ("回形针", "回纹针", "迴紋針", "回形針"),
})

# Portuguese terms.
# Format: English: (BR, PT)
# The Portuguese corpus is lowercased before matching these terms.
_PORTUGUESE_TERMS = immutabledict({
    "Bathroom": ("banheiro", "casa de banho"),
    # Original source had "pequeno almoço", but translator used
    # "pequeno-almoço".
    "Breakfast": ("café da manhã", "pequeno-almoço"),
    "Bus": ("ônibus", "autocarro"),
    "Cup": ("xícara", "chávena"),
    "Computer mouse": ("mouse", "rato"),
    "Drivers license": ("carteira de motorista", "carta de condução"),
    # From Wikipedia page "Ice cream sandwich"
    "Ice cream": ("sorvete", "gelado"),
    "Juice": ("suco", "sumo"),
    "Mobile phone": ("celular", "telemóvel"),
    "Pedestrian": ("pedestre", "peão"),
    # From Wikipedia page "Pickpocketing"
    "Pickpocket": ("batedor de carteiras", "carteirista"),
    "Pineapple": ("abacaxi", "ananás"),
    "Refrigerator": ("geladeira", "frigorífico"),
    "Suit": ("terno", "fato"),
    "Train": ("trem", "comboio"),
    "Video game": ("videogame", "videojogos"),

    # Terms updated after original selection.

    # For BR, replaced "menina" (common in speech) with "garota" (common in
    # writing), matching the human translators.
    "Girl": ("garota", "rapariga"),

    # Replaced the original "Computer monitor": ("tela de computador", "ecrã")
    # with the observed use for just the screen:
    "Screen": ("tela", "ecrã"),

    # Terms excluded.

    # The following three are excluded because they underpin the first 100
    # lexical exemplars used for priming the models.
    ## "Gym": ("academia", "ginásio"),
    ## "Stapler": ("grampeador", "agrafador"),
    ## "Nightgown": ("camisola", "camisa de noite"),

    # The following are excluded for other reasons:

    # The BR translator primarily used 'comissário de bordo' and hardly ever
    # 'aeromoça'. The PT translator used 'comissários/assistentes de bordo' or
    # just 'assistentes de bordo'. Excluding the term as low-signal for now.
    ## "Flight attendant": ("aeromoça", "comissário ao bordo"),

    # Both regions' translators consistently used "presunto", so the term has
    # low signal.
    ## "Ham": ("presunto", "fiambre"),
})
_StrOrPurePath = Union[str, pathlib.PurePath]
def _open_file(path: _StrOrPurePath, mode: str = "r") -> IO[str]:
  return open(path, mode)  # pylint: disable=unreachable
@dataclasses.dataclass(frozen=True)
class TermCount:
  """Counts of matched and mismatched term occurrences in one corpus."""

  matched: int
  mismatched: int
# Maps each source-language term to its occurrence counts.
TermCounts = Mapping[str, TermCount]
class ZhScript(enum.Enum):
  SIMPLIFIED = 1
  TRADITIONAL = 2
def _count_term_hits_with_regex(text: str, term: str) -> int:
  # Word boundaries avoid overtriggering when the term happens to be a
  # substring of an unrelated word.
  pattern = r"\b" + term + r"\b"
  return len(re.findall(pattern, text))
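# Illustrative sketch, not part of the original module: the \b boundaries
# prevent counting a term inside a longer word, unlike plain str.count().
#
#   _count_term_hits_with_regex("o rato comprou um barato", "rato")  # -> 1
#   "o rato comprou um barato".count("rato")                         # -> 2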
def _score_terms(
    corpus: Sequence[str],
    matched_terms: Iterable[str],
    mismatched_terms: Iterable[str],
    per_example_cap: int = 1,
    use_regex: bool = True,
) -> TermCount:
  """Scores terms by counting non-overlapping substring occurrences."""
  matched_total = 0
  mismatched_total = 0

  def _count(sentence: str, term: str) -> int:
    if use_regex:
      return _count_term_hits_with_regex(sentence, term)
    else:
      return sentence.count(term)

  for sentence in corpus:
    matched_term_counts = [
        _count(sentence, matched_term) for matched_term in matched_terms
    ]
    matched_count = min(sum(matched_term_counts), per_example_cap)
    matched_total += matched_count
    mismatched_term_counts = [
        _count(sentence, mismatched_term)
        for mismatched_term in mismatched_terms
    ]
    mismatched_count = min(sum(mismatched_term_counts), per_example_cap)
    mismatched_total += mismatched_count

    for matched_term, matched_term_count in zip(matched_terms,
                                                matched_term_counts):
      if matched_term_count > 0:
        logging.debug("Hit (match) '%s': %s", matched_term, sentence)
    for mismatched_term, mismatched_term_count in zip(mismatched_terms,
                                                      mismatched_term_counts):
      if mismatched_term_count > 0:
        logging.debug("Hit (mismatch) '%s': %s", mismatched_term, sentence)

  return TermCount(matched=matched_total, mismatched=mismatched_total)
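# Illustrative sketch, not part of the original module: with the default
# per_example_cap=1, repeated hits within a single sentence count only once.
#
#   _score_terms(
#       corpus=["o rato viu outro rato", "comprei um telemóvel"],
#       matched_terms=["rato"],
#       mismatched_terms=["telemóvel"])
#   # -> TermCount(matched=1, mismatched=1)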
def score_corpus(
    corpus: Iterable[str],
    terms: Mapping[str, Tuple[Iterable[str], Iterable[str]]],
    use_regex: bool = True,
) -> TermCounts:
  r"""Counts occurrences of matching and non-matching terms in corpus.

  Args:
    corpus: Text to evaluate, in the target language, as an iterable over e.g.
      sentences.
    terms: Map from a source language term to its target language translations
      (x, y) such that x is matched to the regional variant of the corpus, and
      y is mismatched to the regional variant of the corpus. x and y are
      iterables that may contain alternative orthographic realizations of the
      term variant.
    use_regex: Whether to do term matching using a regex that requires word
      boundaries (\b) around the search term; otherwise uses str.count(). Set
      to False if the corpus is in a non-spaced language like Chinese.

  Returns:
    Map from each source language term to a TermCount, recording the number of
    occurrences of the matched and mismatched terms in the corpus.
  """
  corpus = list(corpus)
  return {
      source_word:
          _score_terms(corpus, matching, mismatching, use_regex=use_regex)
      for source_word, (matching, mismatching) in terms.items()
  }
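# Illustrative sketch, not part of the original module: scoring a BR-targeted
# corpus against one term pair, with the BR variant as the match.
#
#   counts = score_corpus(
#       corpus=["comprei um celular novo"],
#       terms={"Mobile phone": (["celular"], ["telemóvel"])})
#   # counts["Mobile phone"] == TermCount(matched=1, mismatched=0)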
def score_pt(corpus_br: Iterable[str],
             corpus_pt: Iterable[str]) -> Tuple[TermCounts, TermCounts]:
  """Calls score_corpus using the hardcoded list of Portuguese terms."""
  # _PORTUGUESE_TERMS is already organized as (match, mismatch) pairs for
  # pt-BR, but needs to be converted from strings to lists of strings.
  counts_br = score_corpus(
      corpus_br,
      terms={
          word: ([br], [pt]) for word, (br, pt) in _PORTUGUESE_TERMS.items()
      })

  # _PORTUGUESE_TERMS must be reorganized as (match, mismatch) pairs for pt-PT.
  counts_pt = score_corpus(
      corpus_pt,
      terms={
          word: ([pt], [br]) for word, (br, pt) in _PORTUGUESE_TERMS.items()
      })

  return counts_br, counts_pt
def score_zh(corpus_cn: Iterable[str], corpus_tw: Iterable[str],
             script_cn: Optional[ZhScript],
             script_tw: Optional[ZhScript]) -> Tuple[TermCounts, TermCounts]:
  """Calls score_corpus using the hardcoded list of Chinese terms."""
  # Reformat the Chinese term dictionary into (match, mismatch) pairs for each
  # corpus, based on the script of the corpus.
  terms_for_cn = {}
  for word, (simp_cn, simp_tw, trad_tw, trad_cn) in _CHINESE_TERMS.items():
    if script_cn == ZhScript.SIMPLIFIED:
      terms_for_cn[word] = ([simp_cn], [simp_tw])
    elif script_cn == ZhScript.TRADITIONAL:
      terms_for_cn[word] = ([trad_cn], [trad_tw])
    else:
      terms_for_cn[word] = ([simp_cn, trad_cn], [simp_tw, trad_tw])
  counts_cn = score_corpus(corpus_cn, terms=terms_for_cn, use_regex=False)

  terms_for_tw = {}
  for word, (simp_cn, simp_tw, trad_tw, trad_cn) in _CHINESE_TERMS.items():
    if script_tw == ZhScript.SIMPLIFIED:
      terms_for_tw[word] = ([simp_tw], [simp_cn])
    elif script_tw == ZhScript.TRADITIONAL:
      terms_for_tw[word] = ([trad_tw], [trad_cn])
    else:
      terms_for_tw[word] = ([simp_tw, trad_tw], [simp_cn, trad_cn])
  counts_tw = score_corpus(corpus_tw, terms=terms_for_tw, use_regex=False)

  return counts_cn, counts_tw
def compute_summary(results: Sequence[TermCounts]) -> float:
  """Returns the matched fraction when summing over the results."""
  tally_matched = 0
  tally_mismatched = 0
  for corpus_counts in results:
    for term_pair in corpus_counts.values():
      tally_matched += term_pair.matched
      tally_mismatched += term_pair.mismatched

  # Return zero if there were no hits at all.
  if tally_matched + tally_mismatched == 0:
    return 0.0

  return tally_matched / (tally_matched + tally_mismatched)
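# Illustrative sketch, not part of the original module: two corpora whose
# counts sum to matched=5 and mismatched=3 yield 5 / (5 + 3) = 0.625.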
def _to_csv(corpus_results: Sequence[TermCounts], lang_codes: Sequence[str],
            path: str) -> None:
  """Outputs results to CSV, assuming parallel corpus_results & lang_codes."""
  assert len(corpus_results) == len(lang_codes), (corpus_results, lang_codes)
  fieldnames = ["source_word"]
  for lang_code in lang_codes:
    fieldnames.append(f"corpus_{lang_code}_matched")
    fieldnames.append(f"corpus_{lang_code}_mismatched")

  # Create a single row for each source language term.
  rows = collections.defaultdict(dict)
  for corpus_result, lang_code in zip(corpus_results, lang_codes):
    for source_word, term_pair in corpus_result.items():
      rows[source_word]["source_word"] = source_word
      rows[source_word][f"corpus_{lang_code}_matched"] = term_pair.matched
      rows[source_word][f"corpus_{lang_code}_mismatched"] = (
          term_pair.mismatched)

  with _open_file(path, "w") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows.values())
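# Illustrative sketch, not part of the original module: for lang_codes
# ("br", "pt"), the CSV has one row per source term under the header
# "source_word, corpus_br_matched, corpus_br_mismatched, corpus_pt_matched,
# corpus_pt_mismatched" (spaces added here for readability).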
def _maybe_read_lines(file_path: Optional[str]) -> Optional[Sequence[str]]:
  """Returns the lines of the given file, or None if no path was given."""
  if file_path is None:
    return None
  with _open_file(file_path) as file:
    return file.readlines()
def _output(summary_metric: float, term_counts: Tuple[TermCounts, TermCounts],
            lang_codes: Tuple[str, str], output_path: Optional[str]) -> None:
  """Writes per-term counts and the summary metric under output_path."""
  if output_path is not None:
    _to_csv(term_counts, lang_codes, f"{output_path}_terms.csv")
    with _open_file(f"{output_path}_lex_acc.txt", mode="wt") as out:
      print(summary_metric, file=out)
def run_pt_eval(
    corpus_br: Sequence[str],
    corpus_pt: Sequence[str],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Portuguese.

  Includes lowercasing the input corpora.

  Args:
    corpus_br: List of BR-targeted translations, parallel to corpus_pt.
    corpus_pt: List of PT-targeted translations, parallel to corpus_br.

  Returns:
    - summary metric
    - TermCounts for BR-corpus and PT-corpus, in that order.
  """
  # Lowercase the Portuguese inputs.
  corpus_br = [line.strip().lower() for line in corpus_br]
  corpus_pt = [line.strip().lower() for line in corpus_pt]

  assert len(corpus_br) == len(corpus_pt), (
      f"{len(corpus_br)} != {len(corpus_pt)}")
  term_counts = score_pt(corpus_br=corpus_br, corpus_pt=corpus_pt)
  summary_metric = compute_summary(term_counts)
  return summary_metric, term_counts
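# Illustrative sketch, not part of the original module: a minimal end-to-end
# call on a parallel pair of one-sentence corpora.
#
#   metric, (counts_br, counts_pt) = run_pt_eval(
#       corpus_br=["Tomei suco no café da manhã."],
#       corpus_pt=["Bebi sumo ao pequeno-almoço."])
#   # Each corpus uses its own region's terms, so metric == 1.0.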
def run_pt_eval_from_files(
    corpus_br_path: str,
    corpus_pt_path: str,
    output_path: Optional[str],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Portuguese from files."""
  with _open_file(corpus_br_path) as file:
    corpus_br = file.readlines()
  with _open_file(corpus_pt_path) as file:
    corpus_pt = file.readlines()
  logging.info("Read %d BR entries from %s", len(corpus_br), corpus_br_path)
  logging.info("Read %d PT entries from %s", len(corpus_pt), corpus_pt_path)

  summary_metric, term_counts = run_pt_eval(
      corpus_br=corpus_br, corpus_pt=corpus_pt)

  # The literal language codes below follow the order of the corpus pair in
  # term_counts.
  _output(summary_metric, term_counts, ("br", "pt"), output_path)
  return summary_metric, term_counts
def run_zh_eval(
    corpus_cn: Sequence[str],
    corpus_tw: Sequence[str],
    script_cn: Optional[ZhScript],
    script_tw: Optional[ZhScript],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Chinese.

  Includes normalizing away spaces in the input corpora.

  Args:
    corpus_cn: List of CN-targeted translations, parallel to corpus_tw.
    corpus_tw: List of TW-targeted translations, parallel to corpus_cn.
    script_cn: The Chinese script to expect for corpus_cn, determining which
      script's term to use for matching. If None, matches against both
      scripts.
    script_tw: The Chinese script to expect for corpus_tw, determining which
      script's term to use for matching. If None, matches against both
      scripts.

  Returns:
    - summary metric
    - TermCounts for CN-corpus and TW-corpus, in that order.
  """
  # Normalize away all spaces, which translators sometimes include in
  # mixed-script words.
  corpus_cn = [line.strip().replace(" ", "") for line in corpus_cn]
  corpus_tw = [line.strip().replace(" ", "") for line in corpus_tw]

  assert len(corpus_cn) == len(corpus_tw), (
      f"{len(corpus_cn)} != {len(corpus_tw)}")
  term_counts = score_zh(
      corpus_cn=corpus_cn,
      corpus_tw=corpus_tw,
      script_cn=script_cn,
      script_tw=script_tw)
  summary_metric = compute_summary(term_counts)
  return summary_metric, term_counts
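# Illustrative sketch, not part of the original module: with explicit scripts,
# each corpus is matched only against its own script's term variants.
#
#   metric, (counts_cn, counts_tw) = run_zh_eval(
#       corpus_cn=["我买了一个新鼠标。"],
#       corpus_tw=["我買了一個新滑鼠。"],
#       script_cn=ZhScript.SIMPLIFIED,
#       script_tw=ZhScript.TRADITIONAL)
#   # Each corpus uses its own region's term, so metric == 1.0.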
def run_zh_eval_from_files(
    corpus_cn_path: str,
    corpus_tw_path: str,
    script_cn: Optional[ZhScript],
    script_tw: Optional[ZhScript],
    output_path: Optional[str],
) -> Tuple[float, Tuple[TermCounts, TermCounts]]:
  """Runs lexical accuracy evaluation on Chinese using file paths."""
  with _open_file(corpus_cn_path) as file:
    corpus_cn = file.readlines()
  with _open_file(corpus_tw_path) as file:
    corpus_tw = file.readlines()
  logging.info("Read %d CN entries from %s", len(corpus_cn), corpus_cn_path)
  logging.info("Read %d TW entries from %s", len(corpus_tw), corpus_tw_path)

  summary_metric, term_counts = run_zh_eval(
      corpus_cn=corpus_cn,
      corpus_tw=corpus_tw,
      script_cn=script_cn,
      script_tw=script_tw)

  # The literal language codes below follow the order of the corpus pair in
  # term_counts.
  _output(summary_metric, term_counts, ("cn", "tw"), output_path)
  return summary_metric, term_counts