h2o-llmstudio/train_wave.py

import os

# Set this before importing any other modules to be on the safe side
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import argparse
import logging
import sys
import time

import psutil

def check_for_done(process_queue):
    """Checks for finished process ids

    Args:
        process_queue: list of process ids
    Returns:
        (True, process_idx) if there is any finished process
        (False, False) if there are no finished processes
    """

    for i, pid in enumerate(process_queue):
        zombie = False
        try:
            p = psutil.Process(pid)
            zombie = p.status() == "zombie"
        except psutil.NoSuchProcess:
            pass
        if not psutil.pid_exists(pid) or zombie:
            return True, i

    return False, False
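
# Example (hypothetical PIDs): poll a queue of two processes and drop the one
# that has exited:
#   done, idx = check_for_done([1234, 5678])
#   if done:
#       process_queue.pop(idx)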


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-C", "--config", help="config filename", default=argparse.SUPPRESS
    )
    parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
    parser.add_argument(
        "-Q",
        "--process-queue",
        help="process queue to wait for",
        default=argparse.SUPPRESS,
    )
    parser_args, _ = parser.parse_known_args(sys.argv)
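
    # Parse the optional comma-separated list of PIDs passed via --process-queue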
    process_queue = []
    if "process_queue" in parser_args and parser_args.process_queue != "":
        process_queue = [int(x) for x in parser_args.process_queue.split(",")]
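
    # Wait until every queued process has finished, polling once every 30 seconds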
    while True:
        if len(process_queue) == 0:
            break
        done, num = check_for_done(process_queue)
        if done:
            process_queue.pop(num)
        else:
            time.sleep(30)

    # delayed imports from llm_studio, only after we want to start training
    import subprocess

    import torch

    from llm_studio.src.utils.config_utils import load_config_py, load_config_yaml
    from llm_studio.src.utils.exceptions import (
        LLMAugmentationsException,
        LLMDataException,
        LLMMetricException,
        LLMModelException,
        LLMTrainingException,
    )
    from llm_studio.src.utils.gpu_utils import is_oom_error
    from llm_studio.src.utils.logging_utils import initialize_logging, write_flag
    from llm_studio.src.utils.utils import kill_ddp_processes
    from train import run
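
    # Load the experiment configuration from either a Python config or a YAML file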
87if "config" in parser_args:88cfg = load_config_py(parser_args.config)89elif "yaml" in parser_args:90cfg = load_config_yaml(parser_args.yaml)91
    flag_path = os.path.join(cfg.output_directory, "flags{}.json")

    # Check if DDP
    if "WORLD_SIZE" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
        if local_rank == 0:
            write_flag(flag_path.format(""), "status", "running")
    else:
        write_flag(flag_path.format(""), "status", "running")
        local_rank = 0

    initialize_logging(cfg)
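
    # Run the experiment; on any failure, write a "failed" status flag and a
    # short info message, then kill any remaining DDP processes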
    try:
        run(cfg=cfg)
    except Exception as exception:
        write_flag(flag_path.format(local_rank), "status", "failed")
        if is_oom_error(exception):
            logging.error(
                "GPU Out-of-Memory (OOM) error occurred. "
                "Please reduce the batch size, input data size, or model size, "
                "or try gradient checkpointing.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "OOM error")

            logging.info(
                "<pre>"
                + subprocess.check_output(["nvidia-smi"]).decode("utf-8")
                + "</pre>"
            )

            if torch.cuda.is_available():
                logging.info(
                    "<pre>" + torch.cuda.memory_summary().replace("-", "=") + "</pre>"
                )

        elif isinstance(exception, LLMDataException):
            logging.error(
                "Data error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Data error")
        elif isinstance(exception, LLMTrainingException):
            logging.error(
                "Training error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Training error")
        elif isinstance(exception, LLMMetricException):
            logging.error(
                "Validation metric failed. Please make sure the selected validation "
                "metric is suitable for your current problem setup.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Metric error")
        elif isinstance(exception, LLMAugmentationsException):
            logging.error(
                "Custom augmentations error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Augmentations error")
        elif isinstance(exception, LLMModelException):
            logging.error(
                "Model error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Model error")
        else:
            logging.error(
                "Exception occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "See logs")
        kill_ddp_processes()
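
# Example invocation (hypothetical paths and PIDs): wait for two running
# experiments to finish, then start training from a YAML config:
#   python train_wave.py -Y /path/to/cfg.yaml -Q 1234,5678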