lama

Форк
0
/
train.py 
73 строки · 2.6 Кб
1
#!/usr/bin/env python3
2

3
import logging
4
import os
5
import sys
6
import traceback
7

8
os.environ['OMP_NUM_THREADS'] = '1'
9
os.environ['OPENBLAS_NUM_THREADS'] = '1'
10
os.environ['MKL_NUM_THREADS'] = '1'
11
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
12
os.environ['NUMEXPR_NUM_THREADS'] = '1'
13

14
import hydra
15
from omegaconf import OmegaConf
16
from pytorch_lightning import Trainer
17
from pytorch_lightning.callbacks import ModelCheckpoint
18
from pytorch_lightning.loggers import TensorBoardLogger
19
from pytorch_lightning.plugins import DDPPlugin
20

21
from saicinpainting.training.trainers import make_training_model
22
from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
23
    handle_deterministic_config
24

25
LOGGER = logging.getLogger(__name__)
26

27

28
@handle_ddp_subprocess()
29
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
30
def main(config: OmegaConf):
31
    try:
32
        need_set_deterministic = handle_deterministic_config(config)
33

34
        if sys.platform != 'win32':
35
            register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log
36

37
        is_in_ddp_subprocess = handle_ddp_parent_process()
38

39
        config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
40
        if not is_in_ddp_subprocess:
41
            LOGGER.info(OmegaConf.to_yaml(config))
42
            OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))
43

44
        checkpoints_dir = os.path.join(os.getcwd(), 'models')
45
        os.makedirs(checkpoints_dir, exist_ok=True)
46

47
        # there is no need to suppress this logger in ddp, because it handles rank on its own
48
        metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
49
        metrics_logger.log_hyperparams(config)
50

51
        training_model = make_training_model(config)
52

53
        trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
54
        if need_set_deterministic:
55
            trainer_kwargs['deterministic'] = True
56

57
        trainer = Trainer(
58
            # there is no need to suppress checkpointing in ddp, because it handles rank on its own
59
            callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
60
            logger=metrics_logger,
61
            default_root_dir=os.getcwd(),
62
            **trainer_kwargs
63
        )
64
        trainer.fit(training_model)
65
    except KeyboardInterrupt:
66
        LOGGER.warning('Interrupted by user')
67
    except Exception as ex:
68
        LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
69
        sys.exit(1)
70

71

72
if __name__ == '__main__':
73
    main()
74

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.