paddlenlp

Форк
0
60 строк · 1.9 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import paddle
16

17
from .log import logger
18

19

20
def get_device_and_mapping():
21
    """
22
    Return device type and name-bool mapping implifying which type is supported.
23
    """
24
    suppoted_device_map = {
25
        "gpu": paddle.is_compiled_with_cuda(),
26
        "xpu": paddle.is_compiled_with_xpu(),
27
        "rocm": paddle.is_compiled_with_rocm(),
28
        "npu": paddle.is_compiled_with_custom_device("npu"),
29
        "cpu": True,
30
    }
31
    for d, v in suppoted_device_map.items():
32
        if v:
33
            return d, suppoted_device_map
34

35

36
def get_device():
37
    """
38
    Return the device with which the paddle is compiled, including 'gpu'(for rocm and gpu), 'npu', 'xpu', 'cpu'.
39
    """
40
    d, _ = get_device_and_mapping()
41
    return d
42

43

44
def synchronize():
45
    """
46
    Synchronize device, return True if succeeded, otherwise return False
47
    """
48
    device = paddle.get_device().split(":")[0]
49
    if device in ["gpu", "rocm"]:
50
        paddle.device.cuda.synchronize()
51
        return True
52
    elif device == "xpu":
53
        paddle.device.xpu.synchronize()
54
        return True
55
    elif device in paddle.device.get_all_custom_device_type():
56
        paddle.device.synchronize()
57
        return True
58
    else:
59
        logger.warning("The synchronization is only supported on cuda and xpu now.")
60
    return False
61

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.