8
# Set the thread-count environment variables of the common numeric backends
# (OpenMP, OpenBLAS, MKL, vecLib, numexpr) to '1'. This runs at module import
# time, ahead of the third-party imports that follow in this file.
# NOTE(review): presumably this keeps native math libraries single-threaded
# per process -- confirm against the intended launch setup.
os.environ.update({
    'OMP_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
    'VECLIB_MAXIMUM_THREADS': '1',
    'NUMEXPR_NUM_THREADS': '1',
})
15
from omegaconf import OmegaConf
16
from pytorch_lightning import Trainer
17
from pytorch_lightning.callbacks import ModelCheckpoint
18
from pytorch_lightning.loggers import TensorBoardLogger
19
from pytorch_lightning.plugins import DDPPlugin
21
from saicinpainting.training.trainers import make_training_model
22
from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
23
handle_deterministic_config
25
# Module-level logger, named after this module via __name__.
LOGGER = logging.getLogger(__name__)
28
@handle_ddp_subprocess()
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
def main(config: OmegaConf):
    """Train the model described by ``config`` with a PyTorch Lightning Trainer.

    Steps, in order:

    1. Ask ``handle_deterministic_config`` whether deterministic mode is needed.
    2. Register debug signal handlers (skipped on Windows).
    3. Call ``handle_ddp_parent_process``; only when it reports that this is
       not a DDP subprocess is the config logged and written to
       ``config.yaml`` in the current working directory.
    4. Create the ``models`` checkpoint directory and a TensorBoard logger.
    5. Build the training model and the ``Trainer``, then call ``fit``.

    All paths are anchored at ``os.getcwd()``.
    NOTE(review): presumably this is the Hydra run directory -- confirm.

    Args:
        config: Configuration object supplied by ``hydra.main``. The fields
            read here are ``visualizer.outdir``, ``location.tb_dir``,
            ``trainer.kwargs`` and ``trainer.checkpoint_kwargs``.

    Raises:
        SystemExit: With status 1 when training fails with any exception
            other than ``KeyboardInterrupt``.
    """
    try:
        need_set_deterministic = handle_deterministic_config(config)

        # The debug signal handlers are not installed on Windows.
        if sys.platform != 'win32':
            register_debug_signal_handlers()

        is_in_ddp_subprocess = handle_ddp_parent_process()

        config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
        if not is_in_ddp_subprocess:
            # Dump the config once, not again from every DDP subprocess.
            LOGGER.info(OmegaConf.to_yaml(config))
            OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))

        checkpoints_dir = os.path.join(os.getcwd(), 'models')
        os.makedirs(checkpoints_dir, exist_ok=True)

        metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
        metrics_logger.log_hyperparams(config)

        training_model = make_training_model(config)

        # Convert to a plain container so the kwargs can be mutated and
        # unpacked into the Trainer constructor.
        trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
        if need_set_deterministic:
            trainer_kwargs['deterministic'] = True

        trainer = Trainer(
            callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
            logger=metrics_logger,
            default_root_dir=os.getcwd(),
            **trainer_kwargs
        )
        trainer.fit(training_model)
    except KeyboardInterrupt:
        LOGGER.warning('Interrupted by user')
    except Exception as ex:
        LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
        # Exit non-zero so launchers and schedulers see the run as failed
        # instead of treating a logged crash as success.
        sys.exit(1)
72
if __name__ == '__main__':