15
from contextlib import contextmanager
21
from peft.import_utils import is_aqlm_available, is_auto_awq_available, is_auto_gptq_available, is_optimum_available
24
def require_torch_gpu(test_case):
    """
    Decorator marking a test that requires a GPU. Will be skipped when no GPU is available.
    """
    if not torch.cuda.is_available():
        return unittest.skip("test requires GPU")(test_case)
    else:
        # A GPU is present: hand the test back unchanged. (Without this branch the
        # decorator would return None and silently destroy the test.)
        return test_case
34
def require_torch_multi_gpu(test_case):
    """
    Decorator marking a test that requires multiple GPUs. Will be skipped when less than 2 GPUs are available.
    """
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return unittest.skip("test requires multiple GPUs")(test_case)
    else:
        # Enough GPUs: return the test unchanged so it actually runs.
        return test_case
44
def require_bitsandbytes(test_case):
    """
    Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library is not
    installed.
    """
    try:
        # Probe for the library; the import is only needed for the availability check.
        import bitsandbytes  # noqa: F401

        test_case = pytest.mark.bitsandbytes(test_case)
    except ImportError:
        test_case = pytest.mark.skip(reason="test requires bitsandbytes")(test_case)
    return test_case
57
def require_auto_gptq(test_case):
    """
    Decorator marking a test that requires auto-gptq. These tests are skipped when auto-gptq isn't installed.
    """
    return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
64
def require_aqlm(test_case):
    """
    Decorator marking a test that requires aqlm. These tests are skipped when aqlm isn't installed.
    """
    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
71
def require_auto_awq(test_case):
    """
    Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed.
    """
    return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case)
78
def require_optimum(test_case):
    """
    Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed.
    """
    return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)
86
@contextmanager
def temp_seed(seed: int):
    """
    Temporarily set the random seed. This works for python numpy, pytorch.

    Args:
        seed (`int`): the seed to set while inside the context; the previous RNG
            states (numpy, torch, and torch.cuda when available) are restored on exit.
    """
    # Save then override the numpy RNG state.
    np_state = np.random.get_state()
    np.random.seed(seed)

    # Save then override the torch CPU RNG state.
    torch_state = torch.random.get_rng_state()
    torch.random.manual_seed(seed)

    if torch.cuda.is_available():
        torch_cuda_states = torch.cuda.get_rng_state_all()
        torch.cuda.manual_seed_all(seed)

    try:
        yield
    finally:
        # Restore all saved RNG states even if the body raised.
        np.random.set_state(np_state)
        torch.random.set_rng_state(torch_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state_all(torch_cuda_states)
109
def get_state_dict(model, unwrap_compiled=True):
    """
    Get the state dict of a model. If the model is compiled, unwrap it first.

    Args:
        model: the model whose state dict to return.
        unwrap_compiled (`bool`, defaults to `True`): if `True` and the model carries a
            `torch.compile` wrapper, return the state dict of the underlying
            `_orig_mod` instead of the wrapper's.
    """
    if unwrap_compiled:
        # torch.compile stores the original module under `_orig_mod`; fall back to
        # the model itself when it is not compiled.
        model = getattr(model, "_orig_mod", model)
    return model.state_dict()