aurora

import gc
import os
import sys
import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

try:
    from transformers.utils import (
        is_torch_bf16_cpu_available,
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available
    )
    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
    # Fallback for older transformers versions that lack these helpers.
    _is_fp16_available = torch.cuda.is_available()
    try:
        _is_bf16_available = torch.cuda.is_bf16_supported()
    except Exception:  # is_bf16_supported() may raise when CUDA is unavailable
        _is_bf16_available = False

if TYPE_CHECKING:
    from transformers import HfArgumentParser
    from llmtuner.hparams import ModelArguments


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
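
# Illustrative usage (not part of the original file): tracking a running
# training loss; `loss` and `batch_size` are hypothetical caller-side names.
#
#     loss_meter = AverageMeter()
#     for loss, batch_size in [(0.9, 8), (0.8, 8), (0.6, 4)]:
#         loss_meter.update(loss, n=batch_size)
#     loss_meter.avg  # weighted mean: sum of val * n divided by sum of n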


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and the total number of parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DeepSpeed ZeRO-3, the weights are initialized empty and the
        # real element count is kept in ds_numel
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4-bit linear layers from bitsandbytes, multiply the number of parameters by 2
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param
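
# Illustrative usage: reporting the trainable fraction of a model, e.g. after
# attaching adapters. `model` is any torch.nn.Module.
#
#     trainable, total = count_parameters(model)
#     print("trainable: {} || all: {} || {:.4f}%".format(
#         trainable, total, 100 * trainable / total))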


def get_current_device() -> torch.device:
    r"""
    Returns the device for the current process, honoring the LOCAL_RANK
    environment variable set by distributed launchers.
    """
    import accelerate
    if accelerate.utils.is_xpu_available():
        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif accelerate.utils.is_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif torch.cuda.is_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
        device = "cpu"

    return torch.device(device)
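
# Illustrative usage: placing data on the device chosen for this process;
# under torchrun each worker sees its own LOCAL_RANK. `batch` is hypothetical.
#
#     device = get_current_device()
#     batch = batch.to(device)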


def get_logits_processor() -> "LogitsProcessorList":
    r"""
    Gets logits processor that removes NaN and Inf logits.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor
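
# Illustrative usage: transformers' generate() accepts a logits_processor
# argument, so NaN/Inf logits can be sanitized during decoding.
#
#     outputs = model.generate(**inputs, logits_processor=get_logits_processor())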


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.
    """
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32
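
# Illustrative behavior: a bfloat16 checkpoint keeps bfloat16 only when the
# hardware supports it; otherwise the dtype degrades gracefully.
#
#     infer_optim_dtype(torch.bfloat16)  # bf16 if available, else fp16, else fp32
#     infer_optim_dtype(torch.float32)   # fp16 if available, else fp32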


def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    r"""
    Parses arguments from an explicit dict, from a single .yaml or .json file
    given on the command line, or from the command-line arguments themselves.
    """
    if args is not None:
        return parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        return parser.parse_args_into_dataclasses()
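
# Illustrative usage with a hypothetical dataclass; any HfArgumentParser works
# since each parse_* method returns a tuple of filled dataclasses.
#
#     parser = HfArgumentParser((ModelArguments,))
#     (model_args,) = parse_args(parser)  # from sys.argv, or a .yaml/.json path
#     (model_args,) = parse_args(parser, {"model_name_or_path": "foo"})  # from a dict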


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
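
# Illustrative usage: free cached GPU memory after dropping large objects,
# e.g. between loading different checkpoints. `model` is hypothetical.
#
#     del model
#     torch_gc()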


def try_download_model_from_ms(model_args: "ModelArguments") -> None:
    r"""
    Downloads the model from the ModelScope hub when USE_MODELSCOPE_HUB is enabled
    and the given path does not exist locally, rewriting model_name_or_path in place.
    """
    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
        return

    try:
        from modelscope import snapshot_download  # type: ignore
        # ModelScope names its default branch "master" rather than "main".
        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
        model_args.model_name_or_path = snapshot_download(
            model_args.model_name_or_path,
            revision=revision,
            cache_dir=model_args.cache_dir
        )
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")


def use_modelscope() -> bool:
    r"""
    Checks whether the USE_MODELSCOPE_HUB environment variable is set to a truthy integer.
    """
    return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
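
# Illustrative usage: with USE_MODELSCOPE_HUB=1 in the environment, a hub model
# id in model_args.model_name_or_path is resolved to a local snapshot path.
#
#     try_download_model_from_ms(model_args)
#     model_args.model_name_or_path  # now a local directory under cache_dir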