# CSS-LM
# 294 lines, 12.5 KB
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmarking the library on inference and training in TensorFlow.
"""
19
20
import logging
import random
import timeit
from functools import wraps
from typing import Callable, Optional, Tuple, Union

from transformers import (
    TF_MODEL_MAPPING,
    TF_MODEL_WITH_LM_HEAD_MAPPING,
    PretrainedConfig,
    is_py3nvml_available,
    is_tf_available,
)

from .benchmark_utils import (
    Benchmark,
    Memory,
    MemorySummary,
    measure_peak_memory_cpu,
    start_memory_tracing,
    stop_memory_tracing,
)


if is_tf_available():
    import tensorflow as tf
    from tensorflow.python.framework.errors_impl import ResourceExhaustedError

    from .benchmark_args_tf import TensorFlowBenchmarkArguments

if is_py3nvml_available():
    import py3nvml.py3nvml as nvml
52
# Logger for this module, keyed by its import path.
logger = logging.getLogger(name=__name__)
54
55
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
    """Return a decorator that runs a function eagerly or as a compiled ``tf.function``.

    Args:
        do_eager_mode: If truthy, the decorated function is simply called (eager execution).
            Cannot be combined with ``use_xla``.
        use_xla: Passed as ``experimental_compile`` to ``tf.function`` in graph mode, i.e.
            requests XLA compilation of the decorated function.

    Returns:
        A decorator; the wrapped function keeps the original's metadata via ``functools.wraps``.

    Raises:
        AssertionError: at decoration time if both ``do_eager_mode`` and ``use_xla`` are set.
    """

    def run_func(func):
        if do_eager_mode:
            assert (
                not use_xla
            ), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."

            @wraps(func)
            def run_in_eager_mode(*args, **kwargs):
                return func(*args, **kwargs)

            return run_in_eager_mode

        # The tf.function wrapper is only built when graph mode is actually requested;
        # eager runs never need it.
        @wraps(func)
        @tf.function(experimental_compile=use_xla)
        def run_in_graph_mode(*args, **kwargs):
            return func(*args, **kwargs)

        return run_in_graph_mode

    return run_func
76
77
def random_input_ids(
    batch_size: int, sequence_length: int, vocab_size: int, seed: Optional[int] = None
) -> "tf.Tensor":
    """Build a ``(batch_size, sequence_length)`` ``tf.int32`` tensor of random token ids.

    Args:
        batch_size: Number of rows.
        sequence_length: Number of token ids per row.
        vocab_size: Ids are drawn uniformly from ``[0, vocab_size)``; must be positive,
            otherwise ``random`` raises ``ValueError``.
        seed: Optional seed for reproducible inputs. ``None`` (the default) seeds the
            generator the same way ``random.Random()`` does.

    Returns:
        A constant tensor of shape ``(batch_size, sequence_length)`` and dtype ``tf.int32``.
    """
    rng = random.Random(seed)
    values = [rng.randrange(vocab_size) for _ in range(batch_size * sequence_length)]
    return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
82
83
class TensorFlowBenchmark(Benchmark):
    """Inference/training speed and memory benchmarks for TensorFlow models.

    Each benchmark builds a model from ``self.config_dict[model_name]``, feeds it random
    token ids and measures it inside ``self.args.strategy.scope()``. Run-time switches
    (eager/XLA mode, fp16, repeat count, line-by-line memory tracing, target device, ...)
    are read from ``self.args``.
    """

    # Quoted so that defining this class does not raise NameError when TensorFlow (and
    # with it ``TensorFlowBenchmarkArguments``) could not be imported.
    args: "TensorFlowBenchmarkArguments"
    configs: PretrainedConfig
    framework: str = "TensorFlow"

    @property
    def framework_version(self):
        """Version string of the installed TensorFlow package."""
        return tf.__version__

    def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> Union[float, str]:
        """Return seconds per forward pass (see ``_measure_speed``), or ``"N/A"`` on OOM."""
        self._check_strategy()
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        return self._measure_speed(_inference)

    def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> Union[float, str]:
        """Return seconds per training step (see ``_measure_speed``), or ``"N/A"`` on OOM."""
        self._check_strategy()
        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        return self._measure_speed(_train)

    def _inference_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> Tuple[Union[Memory, str], Optional[MemorySummary]]:
        """Return ``(memory, summary)`` for one forward pass; see ``_measure_memory``."""
        self._enable_gpu_memory_growth()
        self._check_strategy()
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        return self._measure_memory(_inference)

    def _train_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> Tuple[Union[Memory, str], Optional[MemorySummary]]:
        """Return ``(memory, summary)`` for one training step; see ``_measure_memory``."""
        self._enable_gpu_memory_growth()
        self._check_strategy()
        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        return self._measure_memory(_train)

    def _check_strategy(self) -> None:
        """Fail early when no device strategy is configured (every measurement runs in its scope)."""
        assert self.args.strategy is not None, "A device strategy has to be initialized before using TensorFlow."

    def _enable_gpu_memory_growth(self) -> None:
        """On GPU, enable TensorFlow memory growth for the selected device."""
        if self.args.is_gpu:
            tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)

    def _create_model(self, config: PretrainedConfig, fallback_mapping):
        """Instantiate a TF model for ``config``.

        The class named ``"TF" + config.architectures[0]`` is looked up in ``transformers``
        when the config lists an architecture and ``args.only_pretrain_model`` is off;
        otherwise the class is taken from ``fallback_mapping[config.__class__]``.

        Raises:
            ImportError: if ``transformers`` has no class with the derived name.
        """
        has_model_class_in_config = (
            hasattr(config, "architectures")
            and isinstance(config.architectures, list)
            and len(config.architectures) > 0
        )
        if self.args.only_pretrain_model or not has_model_class_in_config:
            return fallback_mapping[config.__class__](config)

        model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
        try:
            transformers_module = __import__("transformers", fromlist=[model_class])
            model_cls = getattr(transformers_module, model_class)
        except (ImportError, AttributeError) as e:
            # A class missing from an importable package surfaces as AttributeError from
            # ``getattr``, so catching ImportError alone would never produce this message.
            raise ImportError(
                f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
            ) from e
        return model_cls(config)

    @staticmethod
    def _random_inputs(config: PretrainedConfig, batch_size: int, sequence_length: int):
        """Random ``(batch_size, sequence_length)`` token ids within ``config``'s vocabulary."""
        # encoder-decoder configs keep the vocab size on their encoder sub-config
        vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
        return random_input_ids(batch_size, sequence_length, vocab_size)

    def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        """Build a zero-argument callable that runs one forward pass of ``model_name``.

        Raises:
            NotImplementedError: if ``args.fp16`` is set.
            ImportError: see ``_create_model``.
        """
        config = self.config_dict[model_name]

        if self.args.fp16:
            raise NotImplementedError("Mixed precision is currently not supported.")

        model = self._create_model(config, TF_MODEL_MAPPING)
        input_ids = self._random_inputs(config, batch_size, sequence_length)

        # Only the variant that will actually run gets wrapped.
        if config.is_encoder_decoder:

            def encoder_decoder_forward():
                return model(input_ids, decoder_input_ids=input_ids, training=False)

            _inference = encoder_decoder_forward
        else:

            def encoder_forward():
                return model(input_ids, training=False)

            _inference = encoder_forward

        return run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)(_inference)

    def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        """Build a zero-argument callable that computes the LM loss and its gradients once.

        Gradients come from ``tf.gradients``, which requires graph mode; hence eager mode
        is rejected up front.

        Raises:
            AssertionError: if ``args.eager_mode`` is set.
            NotImplementedError: if ``args.fp16`` is set.
            ImportError: see ``_create_model``.
        """
        config = self.config_dict[model_name]

        assert (
            not self.args.eager_mode
        ), "Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`."

        if self.args.fp16:
            raise NotImplementedError("Mixed precision is currently not supported.")

        model = self._create_model(config, TF_MODEL_WITH_LM_HEAD_MAPPING)
        input_ids = self._random_inputs(config, batch_size, sequence_length)

        # The input ids double as labels; the loss is the first model output.
        if config.is_encoder_decoder:

            def encoder_decoder_train():
                loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids, training=True)[0]
                return tf.gradients(loss, model.trainable_variables)

            _train = encoder_decoder_train
        else:

            def encoder_train():
                loss = model(input_ids, labels=input_ids, training=True)[0]
                return tf.gradients(loss, model.trainable_variables)

            _train = encoder_train

        return run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)(_train)

    def _measure_speed(self, func) -> Union[float, str]:
        """Time ``func`` and return seconds per call, or ``"N/A"`` if it runs out of memory.

        ``args.repeat`` rounds of 10 calls each are timed; the fastest round divided by 10
        is returned.
        """
        with self.args.strategy.scope():
            try:
                if self.args.is_tpu or self.args.use_xla:
                    # Run 5 untimed calls first so compilation cost does not skew the timings.
                    logger.info("Do inference on TPU. Running model 5 times to stabilize compilation")
                    timeit.repeat(func, repeat=1, number=5)

                # As https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat advises,
                # take the minimum rather than the average.
                runtimes = timeit.repeat(func, repeat=self.args.repeat, number=10)
                return min(runtimes) / 10.0
            except ResourceExhaustedError as e:
                self.print_fn("Doesn't fit on GPU. {}".format(e))
                # Same sentinel as ``_measure_memory`` instead of an implicit None.
                return "N/A"

    def _measure_memory(self, func: Callable[[], None]) -> Tuple[Union[Memory, str], Optional[MemorySummary]]:
        """Run ``func`` inside ``args.strategy.scope()`` and report its memory usage.

        Returns:
            ``(memory, summary)``. On GPU ``memory`` is NVML's "used" figure after ``func``
            (see ``_measure_gpu_memory``). On CPU it is ``measure_peak_memory_cpu``'s result,
            or, with line-by-line tracing, the trace summary's ``total``. It is ``"N/A"``
            when it cannot be measured or ``func`` runs out of memory. ``summary`` is the
            line-by-line trace summary when ``args.trace_memory_line_by_line`` is set, else None.

        Raises:
            AssertionError: line-by-line tracing requested without eager mode.
            NotImplementedError: on TPU.
        """
        logger.info(
            "Note that TensorFlow allocates more memory than "
            "it might need to speed up computation. "
            "The memory reported here corresponds to the memory "
            "reported by `nvidia-smi`, which can vary depending "
            "on total available memory on the GPU that is used."
        )
        with self.args.strategy.scope():
            try:
                trace_lines = self.args.trace_memory_line_by_line
                if trace_lines:
                    assert (
                        self.args.eager_mode
                    ), "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory consumption line by line."

                if self.args.is_tpu:
                    # Checked before tracing starts so no tracer is left running on TPU.
                    raise NotImplementedError(
                        "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
                    )

                if trace_lines:
                    # NOTE(review): if ``func`` raises below, this trace is never stopped.
                    trace = start_memory_tracing("transformers")

                if self.args.is_gpu:
                    memory = self._measure_gpu_memory(func)
                else:
                    memory = self._measure_cpu_memory(func)

                if trace_lines:
                    summary = stop_memory_tracing(trace)
                    if memory is None:
                        memory = summary.total
                else:
                    summary = None

                return memory, summary
            except ResourceExhaustedError as e:
                self.print_fn("Doesn't fit on GPU. {}".format(e))
                return "N/A", None

    def _measure_gpu_memory(self, func: Callable[[], None]) -> Union[Memory, str]:
        """Run ``func`` and return NVML's "used" memory for ``args.device_idx``, or ``"N/A"``."""
        if not is_py3nvml_available():
            logger.warning(
                "py3nvml not installed, we won't log GPU memory usage. "
                "Install py3nvml (pip install py3nvml) to log information about GPU."
            )
            if self.args.trace_memory_line_by_line:
                # Still run the model so the line-by-line trace has something to record.
                func()
            return "N/A"

        logger.info(
            "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU."
        )
        nvml.nvmlInit()
        try:
            func()
            handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
            meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
            # Device-wide usage read after ``func`` returns, not a per-process peak.
            return Memory(meminfo.used)
        finally:
            # Release NVML even when ``func`` raises (e.g. ResourceExhaustedError).
            nvml.nvmlShutdown()

    def _measure_cpu_memory(self, func: Callable[[], None]) -> Optional[Union[Memory, str]]:
        """Return peak CPU memory of ``func``, or None when line-by-line tracing will supply it."""
        if self.args.trace_memory_line_by_line:
            logger.info(
                "When enabling line by line tracing, the max peak memory for CPU is inaccurate in TensorFlow."
            )
            # The caller falls back to the trace total, so ``func`` must run while tracing is on.
            func()
            return None
        memory_bytes = measure_peak_memory_cpu(func)
        return Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
295