lmops

Форк
0
/
qa_utils.py 
110 строк · 3.4 Кб
1
# Copyright 2022 The T5 Authors.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
"""Utilities for Question Answering (QA) evaluation.
16

17
Matches results on the SQuAD (v1.1) and TriviaQA (v1.0) evaluation scripts.
18
"""
19

20
import collections
21
import re
22
import string
23

24
from absl import logging
25
import numpy as np
26

27

28
def _normalize_answer(text, punc_chars, punc_repl):
29
  """Lower text and remove punctuation, articles and extra whitespace."""
30

31
  def remove_articles(s):
32
    return re.sub(r"\b(a|an|the)\b", " ", s)
33

34
  def replace_punctuation(s):
35
    to_replace = set(punc_chars)
36
    return "".join(punc_repl if ch in to_replace else ch for ch in s)
37

38
  def white_space_fix(s):
39
    return " ".join(s.split())
40

41
  text = text.lower()
42
  text = replace_punctuation(text)
43
  text = remove_articles(text)
44
  text = white_space_fix(text)
45
  return text
46

47

48
def normalize_trivia_qa(answer):
49
  """Normalization used in official TriviaQA evaluation script."""
50
  return _normalize_answer(
51
      answer, punc_chars=string.punctuation + "‘’´`_", punc_repl=" ").strip()
52

53

54
def normalize_squad(answer):
55
  """Normalization used in official SQuAD evaluation script."""
56
  return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="")
57

58

59
def _metric_max_over_ground_truths(metric_fn, ground_truths, prediction):
60
  """Computes the maximum of the metric over all ground truths."""
61
  return max(
62
      metric_fn(ground_truth, prediction) for ground_truth in ground_truths
63
  )
64

65

66
def _exact_match_score(target, prediction):
67
  return target == prediction
68

69

70
def _f1_score(target, prediction):
71
  """Computes token f1 score for a single target and prediction."""
72
  prediction_tokens = prediction.split()
73
  target_tokens = target.split()
74
  common = (collections.Counter(prediction_tokens) &
75
            collections.Counter(target_tokens))
76
  num_same = sum(common.values())
77
  if num_same == 0:
78
    return 0
79
  precision = 1.0 * num_same / len(prediction_tokens)
80
  recall = 1.0 * num_same / len(target_tokens)
81
  f1 = (2 * precision * recall) / (precision + recall)
82
  return f1
83

84
def qa_metrics(targets, predictions, return_list=False):
85
  """Computes exact match and f1 QA scores, expecting pre-normalized text."""
86
  if len(targets) != len(predictions):
87
    raise ValueError("Number of targets and predictions must match.")
88
  if return_list:
89
    em=[
90
        _metric_max_over_ground_truths(_exact_match_score, t, p)
91
        for p, t in zip(predictions, targets)
92
    ]
93
    f1=[
94
        _metric_max_over_ground_truths(_f1_score, t, p)
95
        for p, t in zip(predictions, targets)
96
    ]
97
    return em, f1
98
  em = np.mean([
99
      _metric_max_over_ground_truths(_exact_match_score, t, p)
100
      for p, t in zip(predictions, targets)
101
  ])
102
  f1 = np.mean([
103
      _metric_max_over_ground_truths(_f1_score, t, p)
104
      for p, t in zip(predictions, targets)
105
  ])
106
  em *= 100
107
  f1 *= 100
108
  logging.info("EM = %.2f, F1 = %.2f", em, f1)
109
  #return {"em": em, "f1": f1}
110
  return em, f1

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.