CSS-LM
130 строк · 5.4 Кб
1# coding=utf-8
2# Copyright 2018 The HuggingFace Inc. team.
3# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import dataclasses
18import json
19import logging
20from dataclasses import dataclass, field
21from time import time
22from typing import List
23
24
25logger = logging.getLogger(__name__)
26
27
def list_field(default=None, metadata=None):
    """
    Build a dataclass field whose default value is a (typically list) object.

    Dataclasses reject mutable defaults passed directly via ``default=``, so
    the value is supplied through ``default_factory`` instead.

    Args:
        default: The value every new instance should start with. May be
            ``None``.
        metadata: Optional mapping forwarded unchanged to
            ``dataclasses.field``.

    Returns:
        The ``dataclasses.Field`` produced by ``dataclasses.field``.
    """
    import copy

    # Hand out a fresh shallow copy on each call; returning `default` itself
    # would make every instance share (and be able to mutate) one list.
    return field(default_factory=lambda: copy.copy(default), metadata=metadata)
30
31
@dataclass
class BenchmarkArguments:
    """
    BenchMarkArguments are arguments we use in our benchmark scripts
    **which relate to the training loop itself**.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Note: the default file names below embed ``round(time())``. The f-strings
    are evaluated once, when the class body is executed, so all instances
    created afterwards share the same timestamp in their defaults.
    """

    # --- What to benchmark -------------------------------------------------
    models: List[str] = list_field(
        default=[],
        metadata={
            "help": "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version of all available models"
        },
    )

    batch_sizes: List[int] = list_field(
        default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
    )

    sequence_lengths: List[int] = list_field(
        default=[8, 32, 128, 512],
        metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
    )

    # --- Boolean switches (all default to False) ---------------------------
    no_inference: bool = field(default=False, metadata={"help": "Don't benchmark inference of model"})
    # The two flags below are negative switches; their help text states the
    # effect of setting the flag.
    no_cuda: bool = field(default=False, metadata={"help": "Don't run on available cuda devices"})
    no_tpu: bool = field(default=False, metadata={"help": "Don't run on available tpu devices"})
    fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
    training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
    verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
    no_speed: bool = field(default=False, metadata={"help": "Don't perform speed measurments"})
    no_memory: bool = field(default=False, metadata={"help": "Don't perform memory measurments"})
    trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
    save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
    log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
    no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"})
    no_multi_process: bool = field(
        default=False,
        metadata={
            "help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU."
        },
    )

    # --- Output file names -------------------------------------------------
    inference_time_csv_file: str = field(
        default=f"inference_time_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving time results to csv."},
    )
    inference_memory_csv_file: str = field(
        default=f"inference_memory_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving memory results to csv."},
    )
    train_time_csv_file: str = field(
        default=f"train_time_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving time results to csv for training."},
    )
    train_memory_csv_file: str = field(
        default=f"train_memory_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving memory results to csv for training."},
    )
    env_info_csv_file: str = field(
        default=f"env_info_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving environment information."},
    )
    log_filename: str = field(
        default=f"log_{round(time())}.csv",
        metadata={"help": "Log filename used if print statements are saved in log."},
    )

    # --- Miscellaneous -----------------------------------------------------
    repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
    only_pretrain_model: bool = field(
        default=False,
        metadata={
            "help": "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain model weights."
        },
    )

    def to_json_string(self):
        """
        Serializes this instance to a JSON string.

        Returns:
            The dataclass fields as a JSON object, indented by two spaces.
        """
        return json.dumps(dataclasses.asdict(self), indent=2)

    @property
    def model_names(self):
        """
        Return ``self.models`` after checking that it is not empty.

        Raises:
            AssertionError: If no model was provided.
        """
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so
        # the exception type seen by callers does not change.
        assert (
            len(self.models) > 0
        ), "Please make sure you provide at least one model name / model identifier, *e.g.* `--models bert-base-cased` or `args.models = ['bert-base-cased']."
        return self.models

    @property
    def do_multi_processing(self):
        """
        Tell whether measurements should run in a separate process.

        Returns:
            ``False`` when ``no_multi_process`` is set or when ``is_tpu`` is
            truthy (an info message is logged in the latter case), otherwise
            ``True``.
        """
        if self.no_multi_process:
            return False
        # `is_tpu` is not defined on this class (presumably supplied by a
        # subclass - confirm); treat a missing attribute as "not on TPU"
        # instead of raising AttributeError.
        if getattr(self, "is_tpu", False):
            logger.info("Multiprocessing is currently not possible on TPU.")
            return False
        return True
131