# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass, field
from typing import Optional

from ..file_utils import cached_property, is_tf_available, tf_required
from .benchmark_args_utils import BenchmarkArguments

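
# TensorFlow is imported lazily so this module can be imported without TF
# installed; the `@tf_required` decorator guards the methods that actually use it.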
if is_tf_available():
    import tensorflow as tf


logger = logging.getLogger(__name__)


@dataclass
class TensorFlowBenchmarkArguments(BenchmarkArguments):
    tpu_name: str = field(
        default=None, metadata={"help": "Name of TPU"},
    )
    device_idx: int = field(
        default=0, metadata={"help": "CPU / GPU device index. Defaults to 0."},
    )
    eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager mode."})
    use_xla: bool = field(
        default=False,
        metadata={
            "help": "Benchmark models using XLA JIT compilation. Note that `eager_mode` has to be set to `False`."
        },
    )

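    # Resolve the TPU cluster once and cache the result; yields None when the
    # TPU is disabled or no cluster can be resolved.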
    @cached_property
    @tf_required
    def _setup_tpu(self) -> Optional["tf.distribute.cluster_resolver.TPUClusterResolver"]:
        tpu = None  # stays None when TPU is disabled or resolution fails
        if not self.no_tpu:
            try:
                if self.tpu_name:
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
                else:
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
            except ValueError:
                tpu = None
        return tpu

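    # Build the tf.distribute strategy for the selected device: a TPUStrategy
    # when a TPU resolves, otherwise a OneDeviceStrategy pinned to one GPU or
    # to the CPU.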
    @cached_property
    @tf_required
    def _setup_strategy(self) -> "tf.distribute.Strategy":
        if self.is_tpu:
            tf.config.experimental_connect_to_cluster(self._setup_tpu)
            tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)

            strategy = tf.distribute.experimental.TPUStrategy(self._setup_tpu)
        else:
            # currently no multi gpu is allowed
            if self.is_gpu:
                # TODO: Currently only single GPU is supported
                tf.config.experimental.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
                strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
            else:
                tf.config.experimental.set_visible_devices([], "GPU")  # disable GPU
                strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")

        return strategy

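    # Read-only accessors: `is_tpu`/`strategy` reuse the cached setup above,
    # while the GPU properties query TensorFlow directly.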
    @property
    @tf_required
    def is_tpu(self) -> bool:
        return self._setup_tpu is not None

    @property
    @tf_required
    def strategy(self) -> "tf.distribute.Strategy":
        return self._setup_strategy

    @property
    @tf_required
    def gpu_list(self):
        return tf.config.list_physical_devices("GPU")

    @property
    @tf_required
    def n_gpu(self) -> int:
        if not self.no_cuda:
            return len(self.gpu_list)
        return 0

    @property
    def is_gpu(self) -> bool:
        return self.n_gpu > 0
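
# Minimal usage sketch (illustrative; assumes the inherited BenchmarkArguments
# fields such as `models` -- see benchmark_args_utils.py for the full set):
#
#     args = TensorFlowBenchmarkArguments(models=["bert-base-uncased"])
#     if args.is_gpu:
#         print(f"benchmarking on /gpu:{args.device_idx}")
#     with args.strategy.scope():
#         ...  # build and call the model under the chosen distribution strategy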