CSS-LM

Форк
0
/
benchmark_args_tf.py 
105 строк · 3.5 Кб
1
# coding=utf-8
2
# Copyright 2018 The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16

17
import logging
18
from dataclasses import dataclass, field
19
from typing import Tuple
20

21
from ..file_utils import cached_property, is_tf_available, tf_required
22
from .benchmark_args_utils import BenchmarkArguments
23

24

25
if is_tf_available():
26
    import tensorflow as tf
27

28

29
logger = logging.getLogger(__name__)
30

31

32
@dataclass
33
class TensorFlowBenchmarkArguments(BenchmarkArguments):
34
    tpu_name: str = field(
35
        default=None, metadata={"help": "Name of TPU"},
36
    )
37
    device_idx: int = field(
38
        default=0, metadata={"help": "CPU / GPU device index. Defaults to 0."},
39
    )
40
    eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager model."})
41
    use_xla: bool = field(
42
        default=False,
43
        metadata={
44
            "help": "Benchmark models using XLA JIT compilation. Note that `eager_model` has to be set to `False`."
45
        },
46
    )
47

48
    @cached_property
49
    @tf_required
50
    def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
51
        if not self.no_tpu:
52
            try:
53
                if self.tpu_name:
54
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
55
                else:
56
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
57
            except ValueError:
58
                tpu = None
59
        return tpu
60

61
    @cached_property
62
    @tf_required
63
    def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
64
        if self.is_tpu:
65
            tf.config.experimental_connect_to_cluster(self._setup_tpu)
66
            tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
67

68
            strategy = tf.distribute.experimental.TPUStrategy(self._setup_tpu)
69
        else:
70
            # currently no multi gpu is allowed
71
            if self.is_gpu:
72
                # TODO: Currently only single GPU is supported
73
                tf.config.experimental.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
74
                strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
75
            else:
76
                tf.config.experimental.set_visible_devices([], "GPU")  # disable GPU
77
                strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
78

79
        return strategy
80

81
    @property
82
    @tf_required
83
    def is_tpu(self) -> bool:
84
        return self._setup_tpu is not None
85

86
    @property
87
    @tf_required
88
    def strategy(self) -> "tf.distribute.Strategy":
89
        return self._setup_strategy
90

91
    @property
92
    @tf_required
93
    def gpu_list(self):
94
        return tf.config.list_physical_devices("GPU")
95

96
    @property
97
    @tf_required
98
    def n_gpu(self) -> int:
99
        if not self.no_cuda:
100
            return len(self.gpu_list)
101
        return 0
102

103
    @property
104
    def is_gpu(self) -> bool:
105
        return self.n_gpu > 0
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.