google-research
104 lines · 3.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Code for weighting examples from different tasks based on dataset sizes."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22
23def _multiples_and_weights(config):24"""Helper for weighting GLUE datasets.25
26Concatenating all the train sets together and then shuffling the examples
27causes large datasets to dominate the training, resulting in poor performance
28on small datasets. This has some hacky logic to produce (1) "multiples" for
29each dataset so the multi-task train set contains small datasets multiple
30times, so those examples are seen more often and (2) weights for each dataset,
31which also allows for smaller datasets to have influence on training.
32Overall the effect is that tasks are weighted according to
33dataset_size^config.task_weight_exponent.
34
35Args:
36config: a configure.Config object
37Returns:
38How many copies and weights for each dataset.
39"""
40
41dataset_sizes = {42"cola": 8551,43"mnli": 392702,44"mrpc": 7336,45"qnli": 108436,46"qqp": 363869,47"sst": 67349,48"sts": 11498,49"rte": 249050}51
52def map_values(f, d):53return {k: f(v) for k, v in d.items()}54
55def map_kv(f, d):56return {k: f(k, v) for k, v in d.items()}57
58def normalize(d):59total = float(sum(d.values()))60return map_values(lambda v: v / total, d)61
62dataset_weights = map_values(lambda s: s ** config.task_weight_exponent,63dataset_sizes)64dataset_weights = normalize(dataset_weights)65correction = dataset_sizes["mnli"] / dataset_weights["mnli"]66dataset_tgts = map_values(lambda v: v * correction, dataset_weights)67dataset_multiples = map_kv(68lambda task, tgt: round((tgt + 0.01) / dataset_sizes[task]), dataset_tgts)69new_dataset_sizes = map_kv(70lambda task, multiple: dataset_sizes[task] * multiple, dataset_multiples)71weights_after_multiples = map_values(72lambda v: v * len(dataset_sizes),73normalize({task: dataset_weights[task] / new_dataset_sizes[task]74for task in new_dataset_sizes}))75
76return dataset_multiples, weights_after_multiples77

def get_task_multiple(task, split):
  """Return how many copies of `task`'s data to use for the given split."""
  if split != "train":
    # Duplication only applies to training data.
    return 1
  if not task.config.dataset_multiples:
    return 1
  multiples, _ = _multiples_and_weights(task.config)
  # + 1e-5 so float round-off cannot truncate e.g. 2.9999... down to 2.
  return int(multiples[task.name] + 1e-5)
87
def get_task_weights(config, sizes):
  """Get per-task loss weights according to dataset sizes.

  Args:
    config: a configure.Config object with `dataset_multiples` and
      `task_weight_exponent` attributes.
    sizes: dict mapping task name -> number of train examples.

  Returns:
    Dict mapping task name -> weight, normalized so that the expected weight
    of a training example drawn from the concatenated train set is 1.0.
  """
  if config.dataset_multiples:
    _, weights = _multiples_and_weights(config)
    return weights

  if config.task_weight_exponent < 0:
    # A negative exponent disables re-weighting: every task weighted equally.
    return {task_name: 1.0 for task_name in sizes}
  n_examples = sum(sizes.values())
  # size^(exponent - 1): combined with sampling frequency (proportional to
  # size), each task's total influence ends up proportional to size^exponent.
  weights = {task_name: 1.0 / (size**(1 - config.task_weight_exponent))
             for task_name, size in sizes.items()}
  # Normalize so a uniformly sampled example has expected weight 1
  # (generator expression; no need to materialize a list for sum()).
  expected_weight = sum(weights[task_name] * sizes[task_name] / n_examples
                        for task_name in weights)
  return {task_name: w / expected_weight
          for task_name, w in weights.items()}