# CSS-LM
# (viewer metadata: 86 lines · 2.6 KB — not part of the original source)
1# coding=utf-8
2# Copyright 2018 The HuggingFace Inc. team.
3# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import logging
18from dataclasses import dataclass, field
19from typing import Tuple
20
21from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
22from .benchmark_args_utils import BenchmarkArguments
23
24
# torch and torch_xla are optional heavy dependencies: import them only when
# the corresponding backend is actually installed, so this module stays
# importable in environments without PyTorch or TPU support.
if is_torch_available():
    import torch

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm


# Module-level logger following the standard `logging.getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
33
34
@dataclass
class PyTorchBenchmarkArguments(BenchmarkArguments):
    """Benchmark arguments specific to the PyTorch backend.

    Extends :class:`BenchmarkArguments` with PyTorch-only options
    (torchscript tracing, TPU metric printing, Apex AMP opt level) and
    lazily-resolved device/GPU-count helpers.
    """

    # Trace the models with torchscript before benchmarking.
    torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
    # Print XLA/PyTorch TPU metrics after the benchmark run.
    torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"})
    fp16_opt_level: str = field(
        default="O1",
        metadata={
            "help": (
                # NOTE: a space is required at the end of the first literal —
                # implicit string concatenation would otherwise fuse "'O3']." and "See".
                "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                "See details at https://nvidia.github.io/apex/amp.html"
            )
        },
    )

    @cached_property
    @torch_required
    def _setup_devices(self) -> Tuple["torch.device", int]:
        """Resolve the benchmark device and GPU count once and cache the result.

        Returns:
            Tuple of (`torch.device`, number of GPUs). The GPU count is 0 on
            CPU and TPU; on CUDA it is `torch.cuda.device_count()`.
        """
        logger.info("PyTorch: setting up devices")
        if self.no_cuda:
            device = torch.device("cpu")
            n_gpu = 0
        elif is_torch_tpu_available():
            device = xm.xla_device()
            # TPU cores are not reported as GPUs.
            n_gpu = 0
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            n_gpu = torch.cuda.device_count()
        return device, n_gpu

    @property
    def is_tpu(self):
        """Whether benchmarking should run on TPU (available and not disabled via `no_tpu`)."""
        return is_torch_tpu_available() and not self.no_tpu

    @property
    @torch_required
    def device_idx(self) -> int:
        """Index of the current CUDA device."""
        # TODO(PVP): currently only single GPU is supported
        return torch.cuda.current_device()

    @property
    @torch_required
    def device(self) -> "torch.device":
        """The torch device benchmarks run on."""
        return self._setup_devices[0]

    @property
    @torch_required
    def n_gpu(self):
        """Number of GPUs available (0 on CPU or TPU)."""
        return self._setup_devices[1]

    @property
    def is_gpu(self):
        """Whether at least one GPU is available."""
        return self.n_gpu > 0
87