peft

Форк
0
/
testing_utils.py 
115 строк · 3.6 Кб
1
# Copyright 2023-present the HuggingFace Inc. team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import unittest
15
from contextlib import contextmanager
16

17
import numpy as np
18
import pytest
19
import torch
20

21
from peft.import_utils import is_aqlm_available, is_auto_awq_available, is_auto_gptq_available, is_optimum_available
22

23

24
def require_torch_gpu(test_case):
25
    """
26
    Decorator marking a test that requires a GPU. Will be skipped when no GPU is available.
27
    """
28
    if not torch.cuda.is_available():
29
        return unittest.skip("test requires GPU")(test_case)
30
    else:
31
        return test_case
32

33

34
def require_torch_multi_gpu(test_case):
35
    """
36
    Decorator marking a test that requires multiple GPUs. Will be skipped when less than 2 GPUs are available.
37
    """
38
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
39
        return unittest.skip("test requires multiple GPUs")(test_case)
40
    else:
41
        return test_case
42

43

44
def require_bitsandbytes(test_case):
45
    """
46
    Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library is not installed.
47
    """
48
    try:
49
        import bitsandbytes  # noqa: F401
50

51
        test_case = pytest.mark.bitsandbytes(test_case)
52
    except ImportError:
53
        test_case = pytest.mark.skip(reason="test requires bitsandbytes")(test_case)
54
    return test_case
55

56

57
def require_auto_gptq(test_case):
58
    """
59
    Decorator marking a test that requires auto-gptq. These tests are skipped when auto-gptq isn't installed.
60
    """
61
    return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
62

63

64
def require_aqlm(test_case):
65
    """
66
    Decorator marking a test that requires aqlm. These tests are skipped when aqlm isn't installed.
67
    """
68
    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
69

70

71
def require_auto_awq(test_case):
72
    """
73
    Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed.
74
    """
75
    return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case)
76

77

78
def require_optimum(test_case):
79
    """
80
    Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed.
81
    """
82
    return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)
83

84

85
@contextmanager
86
def temp_seed(seed: int):
87
    """Temporarily set the random seed. This works for python numpy, pytorch."""
88

89
    np_state = np.random.get_state()
90
    np.random.seed(seed)
91

92
    torch_state = torch.random.get_rng_state()
93
    torch.random.manual_seed(seed)
94

95
    if torch.cuda.is_available():
96
        torch_cuda_states = torch.cuda.get_rng_state_all()
97
        torch.cuda.manual_seed_all(seed)
98

99
    try:
100
        yield
101
    finally:
102
        np.random.set_state(np_state)
103

104
        torch.random.set_rng_state(torch_state)
105
        if torch.cuda.is_available():
106
            torch.cuda.set_rng_state_all(torch_cuda_states)
107

108

109
def get_state_dict(model, unwrap_compiled=True):
110
    """
111
    Get the state dict of a model. If the model is compiled, unwrap it first.
112
    """
113
    if unwrap_compiled:
114
        model = getattr(model, "_orig_mod", model)
115
    return model.state_dict()
116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.