# [Hosting-page residue] CSS-LM (fork) — benchmark_args_utils.py (130 lines, 5.4 KB)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

17
import dataclasses
18
import json
19
import logging
20
from dataclasses import dataclass, field
21
from time import time
22
from typing import List
23

24

25
# Module-level logger, named after this module (used by `do_multi_processing`).
logger = logging.getLogger(name=__name__)
26

27

28
def list_field(default=None, metadata=None):
    """Build a dataclass field whose default is a fresh copy of ``default``.

    ``dataclasses`` rejects mutable objects such as lists as a plain
    ``default``, so the value is supplied through ``default_factory``. The
    factory returns a new shallow copy on every call. Each instance therefore
    gets its own list, instead of all instances sharing (and mutating) the
    single object passed in here.

    Args:
        default: The default value (typically a list), or ``None``.
        metadata: Optional mapping stored as the field's ``metadata``.

    Returns:
        A ``dataclasses.Field`` whose ``default_factory`` yields a copy of
        ``default`` (or ``None`` when ``default`` is ``None``).
    """
    import copy  # local import keeps this block self-contained

    def _factory():
        # Returning ``default`` itself would alias one list across every
        # instance; copy it so mutations stay local to each instance.
        return None if default is None else copy.copy(default)

    return field(default_factory=_factory, metadata=metadata)
30

31

32
@dataclass
class BenchmarkArguments:
    """
    BenchmarkArguments are the arguments used by the benchmark scripts. They select
    which models, batch sizes and sequence lengths are benchmarked, which
    measurements (inference/training, speed/memory) are taken, and where the
    results and logs are written.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Note: `do_multi_processing` reads `self.is_tpu`, which is not defined in this
    class; presumably a framework-specific subclass provides it.
    """

    # NOTE(review): the help text says a blank list benchmarks all models, but
    # `model_names` below rejects an empty list — confirm which is intended.
    models: List[str] = list_field(
        default=[],
        metadata={
            "help": "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version of all available models"
        },
    )

    batch_sizes: List[int] = list_field(
        default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
    )

    sequence_lengths: List[int] = list_field(
        default=[8, 32, 128, 512],
        metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
    )

    # Feature switches. The `no_*` flags disable something that is on by default.
    no_inference: bool = field(default=False, metadata={"help": "Don't benchmark inference of model"})
    no_cuda: bool = field(default=False, metadata={"help": "Don't run on available cuda devices"})
    no_tpu: bool = field(default=False, metadata={"help": "Don't run on available tpu devices"})
    fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
    training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
    verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
    no_speed: bool = field(default=False, metadata={"help": "Don't perform speed measurements"})
    no_memory: bool = field(default=False, metadata={"help": "Don't perform memory measurements"})
    trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
    save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
    log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
    no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"})
    no_multi_process: bool = field(
        default=False,
        metadata={
            "help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU."
        },
    )

    # Output file names. The f-string defaults are evaluated once, when this
    # class body executes (i.e. at import time), so every instance created in
    # the same process shares the same timestamped default names.
    inference_time_csv_file: str = field(
        default=f"inference_time_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving time results to csv."},
    )
    inference_memory_csv_file: str = field(
        default=f"inference_memory_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving memory results to csv."},
    )
    train_time_csv_file: str = field(
        default=f"train_time_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving time results to csv for training."},
    )
    train_memory_csv_file: str = field(
        default=f"train_memory_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving memory results to csv for training."},
    )
    env_info_csv_file: str = field(
        default=f"env_info_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving environment information."},
    )
    log_filename: str = field(
        default=f"log_{round(time())}.csv",
        metadata={"help": "Log filename used if print statements are saved in log."},
    )
    repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
    only_pretrain_model: bool = field(
        default=False,
        metadata={
            "help": "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain model weights."
        },
    )

    def to_json_string(self):
        """
        Serializes this instance to a JSON string (all dataclass fields, indented by 2).
        """
        return json.dumps(dataclasses.asdict(self), indent=2)

    @property
    def model_names(self):
        """Return the list of model names / identifiers to benchmark.

        Raises:
            AssertionError: if `models` is empty. This is an `assert`, so the
                check is skipped when Python runs with `-O`.
        """
        assert (
            len(self.models) > 0
        ), "Please make sure you provide at least one model name / model identifier, *e.g.* `--models bert-base-cased` or `args.models = ['bert-base-cased']`."
        return self.models

    @property
    def do_multi_processing(self):
        """Whether measurements should be run with multiprocessing.

        Returns False if `no_multi_process` is set, or if `self.is_tpu` is truthy
        (logging why). Otherwise returns True. `is_tpu` is not defined in this
        class and must be supplied by the instance's concrete type.
        """
        if self.no_multi_process:
            return False
        elif self.is_tpu:
            logger.info("Multiprocessing is currently not possible on TPU.")
            return False
        else:
            return True
131

# [Hosting-page residue, not part of the original source file] Cookie notice:
# "Use of cookies. We use cookies in accordance with the Privacy Policy and the
# Cookie Policy. By clicking 'Accept', you give SberTech JSC consent to process
# your personal data in order to improve our website and the GitVerse service and
# make them more convenient to use. You can disable cookies yourself in your
# browser settings."