h2o-llmstudio

train_wave.py 
163 lines · 5.4 KB
import os

# Set this before importing any other modules to be on the safe side
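# Capping the BLAS/OpenMP thread pools and disabling tokenizer parallelism
# avoids CPU oversubscription (and tokenizer fork warnings) when several
# training processes share one machine.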
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import argparse
import logging
import sys
import time

import psutil


def check_for_done(process_queue):
    """Checks for finished process ids

    Args:
        process_queue: list of process ids
    Returns:
        (True, process_idx) if any process in the queue has finished
        (False, False) if no process has finished yet
    """

    for i, pid in enumerate(process_queue):
        zombie = False
        try:
            p = psutil.Process(pid)
            zombie = p.status() == "zombie"
        except psutil.NoSuchProcess:
            pass
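        # A terminated child that has not been reaped still "exists" as a
        # zombie process, so psutil.pid_exists() alone cannot detect completion.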
        if not psutil.pid_exists(pid) or zombie:
            return True, i

    return False, False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-C", "--config", help="config filename", default=argparse.SUPPRESS
    )
    parser.add_argument("-Y", "--yaml", help="yaml filename", default=argparse.SUPPRESS)
    parser.add_argument(
        "-Q",
        "--process-queue",
        help="process queue to wait for",
        default=argparse.SUPPRESS,
    )
    parser_args, _ = parser.parse_known_args(sys.argv)
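    # Typical invocation (hypothetical paths and PIDs):
    #   python train_wave.py -Y my_experiment.yaml -Q 1234,5678
    # waits for PIDs 1234 and 5678 to exit before training with the yaml config.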

    process_queue = []
    if "process_queue" in parser_args and parser_args.process_queue != "":
        process_queue = [int(x) for x in parser_args.process_queue.split(",")]
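
    # Poll every 30 seconds until every queued PID has exited, so queued
    # experiments run one after another before this training run starts.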
    while True:
        if len(process_queue) == 0:
            break
        done, num = check_for_done(process_queue)
        if done:
            process_queue.pop(num)
        else:
            time.sleep(30)

    # delayed imports from llm_studio, only after we want to start training
    import subprocess

    import torch

    from llm_studio.src.utils.config_utils import load_config_py, load_config_yaml
    from llm_studio.src.utils.exceptions import (
        LLMAugmentationsException,
        LLMDataException,
        LLMMetricException,
        LLMModelException,
        LLMTrainingException,
    )
    from llm_studio.src.utils.gpu_utils import is_oom_error
    from llm_studio.src.utils.logging_utils import initialize_logging, write_flag
    from llm_studio.src.utils.utils import kill_ddp_processes
    from train import run

    if "config" in parser_args:
        cfg = load_config_py(parser_args.config)
    elif "yaml" in parser_args:
        cfg = load_config_yaml(parser_args.yaml)

    flag_path = os.path.join(cfg.output_directory, "flags{}.json")
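    # The "{}" placeholder carries the rank: the global status flag lives in
    # flags.json, while per-rank failure info goes to e.g. flags0.json.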
93

94
    # Check if DDP
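    # WORLD_SIZE is set by the distributed launcher (e.g. torchrun); only
    # local rank 0 writes the global "running" flag to avoid duplicate writes.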
    if "WORLD_SIZE" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
        if local_rank == 0:
            write_flag(flag_path.format(""), "status", "running")
    else:
        write_flag(flag_path.format(""), "status", "running")
        local_rank = 0

    initialize_logging(cfg)

    try:
        run(cfg=cfg)
    except Exception as exception:
        write_flag(flag_path.format(local_rank), "status", "failed")
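        # Map the exception to a short, user-facing reason in the per-rank
        # flag file; full tracebacks go to the logs via exc_info=True.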
109
        if is_oom_error(exception):
110
            logging.error(
111
                "GPU Out-of-Memory (OOM) error occurred. "
112
                "Please, reduce the batch size, or input data size, "
113
                "or model size. Or try gradient checkpointing.",
114
                exc_info=True,
115
            )
116
            write_flag(flag_path.format(local_rank), "info", "OOM error")
117

118
            logging.info(
119
                "<pre>"
120
                + subprocess.check_output(["nvidia-smi"]).decode("utf-8")
121
                + "</pre>"
122
            )
123

124
            if torch.cuda.is_available():
125
                logging.info(
126
                    "<pre>" + torch.cuda.memory_summary().replace("-", "=") + "</pre>"
127
                )
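                # The "-" to "=" swap is presumably so the summary's dashed
                # rulers are not rendered as horizontal rules in the UI logs.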

        elif isinstance(exception, LLMDataException):
            logging.error(
                "Data error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Data error")
        elif isinstance(exception, LLMTrainingException):
            logging.error(
                "Training error occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "Training error")
        elif isinstance(exception, LLMMetricException):
            logging.error(
                "Validation metric failed. Please make sure the selected validation "
                "metric is suitable for your current problem setup.",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Metric error")
        elif isinstance(exception, LLMAugmentationsException):
            logging.error(
                "Custom augmentations error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Augmentations error")
        elif isinstance(exception, LLMModelException):
            logging.error(
                "Model error occurred during H2O LLM Studio run:",
                exc_info=True,
            )
            write_flag(flag_path.format(local_rank), "info", "Model error")
        else:
            logging.error(
                "Exception occurred during H2O LLM Studio run:", exc_info=True
            )
            write_flag(flag_path.format(local_rank), "info", "See logs")
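        # After any failure, tear down remaining DDP worker processes so the
        # job does not hang with orphaned ranks.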
        kill_ddp_processes()
