google-research / task_weighting.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code for weighting examples from different tasks based on dataset sizes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def _multiples_and_weights(config):
  """Helper for weighting GLUE datasets.

  Concatenating all the train sets together and then shuffling the examples
  causes large datasets to dominate the training, resulting in poor
  performance on small datasets. This has some hacky logic to produce
  (1) "multiples" for each dataset, so the multi-task train set contains
  small datasets multiple times and those examples are seen more often, and
  (2) weights for each dataset, which down-weight examples from large
  datasets so small datasets still influence training. Overall the effect is
  that tasks are weighted according to
  dataset_size^config.task_weight_exponent.

  Args:
    config: a configure.Config object

  Returns:
    How many copies and weights to use for each dataset.
  """

  # Number of training examples in each GLUE task.
  dataset_sizes = {
      "cola": 8551,
      "mnli": 392702,
      "mrpc": 7336,
      "qnli": 108436,
      "qqp": 363869,
      "sst": 67349,
      "sts": 11498,
      "rte": 2490
  }

  def map_values(f, d):
    return {k: f(v) for k, v in d.items()}

  def map_kv(f, d):
    return {k: f(k, v) for k, v in d.items()}

  def normalize(d):
    total = float(sum(d.values()))
    return map_values(lambda v: v / total, d)

  # Raw task weights, proportional to dataset_size^task_weight_exponent.
  dataset_weights = map_values(lambda s: s ** config.task_weight_exponent,
                               dataset_sizes)
  dataset_weights = normalize(dataset_weights)
  # Rescale the normalized weights into target example counts so that MNLI,
  # the largest dataset, keeps its original size (and thus gets multiple 1).
  correction = dataset_sizes["mnli"] / dataset_weights["mnli"]
  dataset_tgts = map_values(lambda v: v * correction, dataset_weights)
  # Round each target count to a whole number of copies of the dataset.
  dataset_multiples = map_kv(
      lambda task, tgt: round((tgt + 0.01) / dataset_sizes[task]), dataset_tgts)
  new_dataset_sizes = map_kv(
      lambda task, multiple: dataset_sizes[task] * multiple, dataset_multiples)
  # Per-example loss weights that correct for the rounding, normalized so the
  # mean weight across tasks is 1.
  weights_after_multiples = map_values(
      lambda v: v * len(dataset_sizes),
      normalize({task: dataset_weights[task] / new_dataset_sizes[task]
                 for task in new_dataset_sizes}))

  return dataset_multiples, weights_after_multiples
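
# Worked illustration (added note; numbers are approximate): with
# task_weight_exponent = 0.75, MNLI (392,702 examples) and RTE (2,490
# examples) get raw weights 392702**0.75 ~ 15,700 and 2490**0.75 ~ 350,
# shrinking their size ratio from ~158:1 to ~45:1. The MNLI-anchored
# correction then targets 392,702 MNLI examples (multiple 1) and ~8,800 RTE
# examples, so RTE's train set is repeated round(8800 / 2490) = 4 times.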


def get_task_multiple(task, split):
  """Returns how many copies of the task's train set to use."""
  if split != "train":
    return 1
  if task.config.dataset_multiples:
    multiples, _ = _multiples_and_weights(task.config)
    # Add a small epsilon so float rounding error can't change the int.
    return int(multiples[task.name] + 1e-5)
  return 1


def get_task_weights(config, sizes):
  """Get task weights according to dataset sizes."""

  if config.dataset_multiples:
    _, weights = _multiples_and_weights(config)
    return weights
  else:
    if config.task_weight_exponent < 0:
      return {task_name: 1.0 for task_name in sizes}
    n_examples = sum(sizes.values())
    # Weight each task by size**(exponent - 1) so its total contribution
    # (weight * size) is proportional to size**exponent.
    weights = {task_name: 1.0 / (size**(1 - config.task_weight_exponent))
               for task_name, size in sizes.items()}
    # Normalize so a uniformly sampled training example has expected weight 1.
    expected_weight = sum([weights[task_name] * sizes[task_name] / n_examples
                           for task_name in weights])
    weights = {task_name: w / expected_weight
               for task_name, w in weights.items()}
    return weights
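
# --- Usage sketch (added illustration, not part of the original module) ---
# configure.Config isn't available here, so a namedtuple stands in for it,
# exposing only the two attributes the functions above read:
# `dataset_multiples` and `task_weight_exponent`. The sizes dict reuses a few
# of the GLUE counts hardcoded in _multiples_and_weights.
if __name__ == "__main__":
  import collections

  FakeConfig = collections.namedtuple(
      "FakeConfig", ["dataset_multiples", "task_weight_exponent"])
  sizes = {"mnli": 392702, "rte": 2490, "cola": 8551}

  # dataset_multiples=True: weights come from _multiples_and_weights (the
  # `sizes` argument is ignored and the hardcoded GLUE sizes are used).
  config = FakeConfig(dataset_multiples=True, task_weight_exponent=0.75)
  print(get_task_weights(config, sizes))

  # dataset_multiples=False: each task gets weight size**(exponent - 1),
  # normalized so a uniformly sampled example has expected weight 1.
  config = FakeConfig(dataset_multiples=False, task_weight_exponent=0.75)
  print(get_task_weights(config, sizes))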
