5
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
6
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
9
# Probe device/dtype support once at import time. Newer transformers versions
# expose dedicated helpers; fall back to raw torch probing on older versions.
try:
    from transformers.utils import (
        is_torch_bf16_cpu_available,
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available
    )
    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
    _is_fp16_available = torch.cuda.is_available()
    try:
        _is_bf16_available = torch.cuda.is_bf16_supported()
    except Exception:  # is_bf16_supported may raise on CUDA-less / old builds
        _is_bf16_available = False
25
# Imports used only in type annotations; guarded so they are never executed
# at runtime (avoids importing heavy/optional packages just for typing).
if TYPE_CHECKING:
    from transformers import HfArgumentParser
    from llmtuner.hparams import ModelArguments
31
Computes and stores the average and current value.
42
def update(self, val, n=1):
46
self.avg = self.sum / self.count
49
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # With DeepSpeed ZeRO-3, parameters are sharded and numel() is 0;
        # ds_numel holds the true (gathered) element count.
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # bitsandbytes 4-bit packs two values per byte, so numel() reports
        # half the logical parameter count; double it back.
        if param.__class__.__name__ == "Params4bit":
            num_params = num_params * 2

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param
71
def get_current_device() -> torch.device:
    r"""
    Gets the currently available device for this process (XPU, NPU or CUDA,
    falling back to CPU), indexed by LOCAL_RANK under distributed launchers.
    """
    if accelerate.utils.is_xpu_available():
        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif accelerate.utils.is_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif torch.cuda.is_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
        # No accelerator present: fall back to CPU (this branch was lost in
        # the extracted source; without it `device` could be unbound).
        device = "cpu"

    return torch.device(device)
85
def get_logits_processor() -> "LogitsProcessorList":
    r"""
    Gets logits processor that removes NaN and Inf logits.
    """
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor
94
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    r"""
    Infers the optimal dtype according to the model_dtype and device compatibility.

    Prefers bf16 when the model was trained in bf16 and the device supports it,
    then fp16, then fp32. (Return statements were lost in the extracted source;
    reconstructed from the branch structure.)
    """
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    elif _is_fp16_available:
        return torch.float16
    else:
        return torch.float32
106
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    r"""
    Parses arguments into dataclasses from (in priority order): an explicit dict,
    a single .yaml or .json file given on the command line, or sys.argv itself.
    """
    if args is not None:
        return parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        return parser.parse_args_into_dataclasses()
117
def torch_gc() -> None:
    r"""
    Releases cached GPU memory (no-op when CUDA is unavailable).
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
127
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
    r"""
    Downloads the model from the ModelScope hub when enabled, rewriting
    model_args.model_name_or_path to the local snapshot directory.

    No-op when the ModelScope hub is disabled or the path already exists
    locally. Raises ImportError if `modelscope` is not installed.
    """
    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
        return

    try:
        from modelscope import snapshot_download
        # ModelScope names its default branch "master" where the HF hub uses "main".
        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
        model_args.model_name_or_path = snapshot_download(
            model_args.model_name_or_path,
            revision=revision,
            cache_dir=model_args.cache_dir
        )
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")
143
def use_modelscope() -> bool:
    r"""
    Whether the USE_MODELSCOPE_HUB environment variable enables the ModelScope hub.
    Unset defaults to disabled; the value must parse as an integer.
    """
    flag = os.environ.get("USE_MODELSCOPE_HUB", "0")
    return int(flag) != 0