when-to-switch

Форк
0
/
train_epom.py 
117 строк · 4.1 Кб
1
import json
2
import sys
3
from argparse import Namespace
4
from pathlib import Path
5

6
import numpy as np
7
import wandb
8
import yaml
9
from sample_factory.algorithms.utils.algo_utils import EXTRA_EPISODIC_STATS_PROCESSING, EXTRA_PER_POLICY_SUMMARIES
10
from sample_factory.envs.env_registry import global_env_registry
11
from sample_factory.run_algorithm import run_algorithm
12
from sample_factory.utils.utils import log
13

14
# noinspection PyUnresolvedReferences
15
from learning.encoder import ResnetEncoder
16
from learning.epom_config import Environment, Experiment
17
from learning.grid_memory import GridMemoryWrapper
18
from pomapf_env.env import make_pomapf
19
from pomapf_env.wrappers import MatrixObservationWrapper
20

21

22
def make_env(env_cfg: "Environment | None" = None):
    """Build a bare POMAPF environment from an ``Environment`` config.

    Args:
        env_cfg: Environment settings; only ``env_cfg.grid_config`` is read
            here. When omitted (or ``None``), a fresh ``Environment()`` is
            created for this call.

    Returns:
        Whatever ``make_pomapf`` returns for the given grid config.
    """
    # A default of ``Environment()`` in the signature would be evaluated once
    # at import time and shared by every call; build it lazily instead.
    if env_cfg is None:
        env_cfg = Environment()
    return make_pomapf(grid_config=env_cfg.grid_config)
25

26

27
def create_pogema_env(full_env_name, cfg=None, env_config=None):
    """Environment factory registered with Sample Factory for the POMAPF prefix.

    Parses ``cfg.full_config['environment']`` into an ``Environment``, builds
    the base env via ``make_env`` and wraps it first in ``GridMemoryWrapper``
    and then in ``MatrixObservationWrapper``.

    Args:
        full_env_name: Env name supplied by the caller; unused here.
        cfg: Run config exposing ``full_config['environment']``; passing
            ``None`` fails on attribute access.
        env_config: Per-env config supplied by the caller; unused here.

    Returns:
        The ``MatrixObservationWrapper``-wrapped environment.
    """
    env_settings: Environment = Environment(**cfg.full_config['environment'])
    base_env = make_env(env_settings)
    # Use the dedicated grid-memory radius when it is truthy; otherwise
    # (e.g. None or 0) fall back to the grid's observation radius.
    memory_radius = env_settings.grid_memory_obs_radius or env_settings.grid_config.obs_radius
    memory_env = GridMemoryWrapper(base_env, obs_radius=memory_radius)
    return MatrixObservationWrapper(memory_env)
34

35

36
def register_custom_components():
    """Register the POMAPF environment and its summary hooks with Sample Factory.

    Registers ``create_pogema_env`` under the ``'POMAPF'`` env-name prefix and
    adds ``pogema_extra_episodic_stats_processing`` and
    ``pogema_extra_summaries`` to Sample Factory's hook lists. Each callback
    is appended only if it is not already present, so calling this function
    more than once does not produce duplicate hook entries.
    """
    global_env_registry().register_env(
        env_name_prefix='POMAPF',
        make_env_func=create_pogema_env,
    )

    # These hook lists are shared objects imported from Sample Factory;
    # guard against registering the same callback twice on repeated calls.
    hook_registrations = (
        (EXTRA_EPISODIC_STATS_PROCESSING, pogema_extra_episodic_stats_processing),
        (EXTRA_PER_POLICY_SUMMARIES, pogema_extra_summaries),
    )
    for hooks, callback in hook_registrations:
        if callback not in hooks:
            hooks.append(callback)
44

45

46
def pogema_extra_episodic_stats_processing(policy_id, stat_key, stat_value, cfg):
    """Episodic-stats hook for POMAPF; intentionally a no-op.

    All arguments are accepted to match the hook signature and ignored; the
    function always returns ``None``.
    """
48

49

50
def pogema_extra_summaries(policy_id, policy_avg_stats, env_steps, summary_writer, cfg):
    """Write the per-policy mean of each custom stat to ``summary_writer``.

    Args:
        policy_id: Index used to pick this policy's values out of each stat.
        policy_avg_stats: Mapping of stat name to per-policy values;
            ``policy_avg_stats[key][policy_id]`` is assumed to be a sequence
            of numbers (TODO confirm against Sample Factory).
        env_steps: Step value passed to ``add_scalar`` as the x-axis.
        summary_writer: Object exposing ``add_scalar(tag, value, step)``.
        cfg: Run config; unused here.

    Stats named ``reward``, ``len``, ``true_reward`` and ``Done`` are skipped
    (presumably already summarized elsewhere). Stats with no values for this
    policy are skipped too, since their mean would be nan.
    """
    skipped_keys = {'reward', 'len', 'true_reward', 'Done'}
    for key in policy_avg_stats:
        if key in skipped_keys:
            continue

        values = np.array(policy_avg_stats[key][policy_id])
        # np.mean of an empty array returns nan (with a RuntimeWarning);
        # don't write such placeholder points to the summaries.
        if values.size == 0:
            continue

        avg = np.mean(values)
        summary_writer.add_scalar(key, avg, env_steps)
        log.debug(f'{policy_id}-{key}: {round(float(avg), 3)}')
58

59

60
def validate_config(config):
    """Parse a raw config mapping and flatten it for ``run_algorithm``.

    Args:
        config: Mapping (e.g. loaded from YAML) accepted by ``Experiment``.

    Returns:
        A tuple ``(exp, flat_config)``: ``exp`` is the parsed ``Experiment``;
        ``flat_config`` is an ``argparse.Namespace`` holding every field of
        the ``async_ppo``, ``experiment_settings``, ``global_settings`` and
        ``evaluation`` sections side by side, plus ``full_config`` (the whole
        experiment as a dict).

    Raises:
        TypeError: if two sections define the same field name (duplicate
            keyword arguments to ``Namespace``).
        Whatever ``Experiment`` raises for invalid input.
    """
    experiment = Experiment(**config)

    ppo_fields = experiment.async_ppo.dict()
    settings_fields = experiment.experiment_settings.dict()
    global_fields = experiment.global_settings.dict()
    evaluation_fields = experiment.evaluation.dict()

    flattened = Namespace(
        **ppo_fields,
        **settings_fields,
        **global_fields,
        **evaluation_fields,
        full_config=experiment.dict(),
    )
    return experiment, flattened
69

70

71
def select_free_dir_name(rood_dir, max_id=100000):
    """Return the first unused zero-padded sub-directory name under ``rood_dir``.

    Candidates ``"0001"``, ``"0002"``, ... up to ``str(max_id).zfill(4)`` are
    checked in order; the first one that does not exist inside ``rood_dir``
    is returned (the bare name, not the full path). Nothing is created on
    disk, so concurrent callers may receive the same name.

    Args:
        rood_dir: Parent directory (str or path-like) to probe.
        max_id: Largest numeric id to try; exactly ``max_id`` candidates
            are checked.

    Returns:
        The free folder name as a string.

    Raises:
        KeyError: if every candidate already exists.
    """
    root = Path(rood_dir)
    # Inclusive upper bound so the loop makes exactly ``max_id`` attempts,
    # as the error message states.
    for cnt in range(1, max_id + 1):
        free_folder = str(cnt).zfill(4)
        if not (root / free_folder).exists():
            return free_folder
    # KeyError is kept for compatibility with existing callers.
    raise KeyError(f"Can't select a folder in {max_id} attempts")
78

79

80
def main():
    """Script entry point: parse CLI args, load the YAML config and train.

    Registers the POMAPF components, loads and validates the config from
    ``--config_path``, picks a free ``experiments_root`` under ``train_dir``
    when none is set, optionally initialises Weights & Biases, and finally
    calls ``run_algorithm`` with the flattened config.

    Returns:
        The value returned by ``run_algorithm``.
    """
    import argparse
    import os

    def str_to_bool(value):
        """Parse a CLI boolean such as ``True``/``false``/``1``/``no``.

        ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``,
        so ``--wandb_thread_mode False`` would have enabled thread mode.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='Process training config.')

    parser.add_argument('--config_path', type=str, action="store", default='configs/train-debug.yaml',
                        help='path to yaml file with single run configuration', required=False)

    parser.add_argument('--wandb_thread_mode', type=str_to_bool, action='store', default=False,
                        help='Run wandb in thread mode. Usefull for some setups.', required=False)

    params = parser.parse_args()

    register_custom_components()
    if params.config_path is None:
        raise ValueError("You should specify --config_path argument!")
    with open(params.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    exp, flat_config = validate_config(config)
    if exp.global_settings.experiments_root is None:
        # Pick a fresh numbered sub-folder and re-validate so flat_config
        # reflects the chosen experiments_root.
        exp.global_settings.experiments_root = select_free_dir_name(exp.global_settings.train_dir)
        exp, flat_config = validate_config(exp.dict())
    log.debug(exp.global_settings.experiments_root)

    if exp.global_settings.use_wandb:
        if params.wandb_thread_mode:
            # Must be set before wandb.init() for it to take effect.
            os.environ["WANDB_START_METHOD"] = "thread"
        wandb.init(project=exp.name, config=exp.dict(), save_code=False, sync_tensorboard=True, anonymous="allow")

    status = run_algorithm(flat_config)

    return status
114

115

116
if __name__ == '__main__':
    # Use main()'s return value as the process exit status
    # (equivalent to sys.exit(main())).
    raise SystemExit(main())
118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.