CSS-LM

Форк
0
/
benchmark_args.py 
86 строк · 2.6 Кб
1
# coding=utf-8
2
# Copyright 2018 The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16

17
import logging
18
from dataclasses import dataclass, field
19
from typing import Tuple
20

21
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
22
from .benchmark_args_utils import BenchmarkArguments
23

24

25
if is_torch_available():
26
    import torch
27

28
if is_torch_tpu_available():
29
    import torch_xla.core.xla_model as xm
30

31

32
logger = logging.getLogger(__name__)
33

34

35
@dataclass
36
class PyTorchBenchmarkArguments(BenchmarkArguments):
37
    torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
38
    torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"})
39
    fp16_opt_level: str = field(
40
        default="O1",
41
        metadata={
42
            "help": (
43
                "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
44
                "See details at https://nvidia.github.io/apex/amp.html"
45
            )
46
        },
47
    )
48

49
    @cached_property
50
    @torch_required
51
    def _setup_devices(self) -> Tuple["torch.device", int]:
52
        logger.info("PyTorch: setting up devices")
53
        if self.no_cuda:
54
            device = torch.device("cpu")
55
            n_gpu = 0
56
        elif is_torch_tpu_available():
57
            device = xm.xla_device()
58
            n_gpu = 0
59
        else:
60
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
            n_gpu = torch.cuda.device_count()
62
        return device, n_gpu
63

64
    @property
65
    def is_tpu(self):
66
        return is_torch_tpu_available() and not self.no_tpu
67

68
    @property
69
    @torch_required
70
    def device_idx(self) -> int:
71
        # TODO(PVP): currently only single GPU is supported
72
        return torch.cuda.current_device()
73

74
    @property
75
    @torch_required
76
    def device(self) -> "torch.device":
77
        return self._setup_devices[0]
78

79
    @property
80
    @torch_required
81
    def n_gpu(self):
82
        return self._setup_devices[1]
83

84
    @property
85
    def is_gpu(self):
86
        return self.n_gpu > 0
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.