when-to-switch
/
train_epom.py
117 строк · 4.1 Кб
1import json
2import sys
3from argparse import Namespace
4from pathlib import Path
5
6import numpy as np
7import wandb
8import yaml
9from sample_factory.algorithms.utils.algo_utils import EXTRA_EPISODIC_STATS_PROCESSING, EXTRA_PER_POLICY_SUMMARIES
10from sample_factory.envs.env_registry import global_env_registry
11from sample_factory.run_algorithm import run_algorithm
12from sample_factory.utils.utils import log
13
14# noinspection PyUnresolvedReferences
15from learning.encoder import ResnetEncoder
16from learning.epom_config import Environment, Experiment
17from learning.grid_memory import GridMemoryWrapper
18from pomapf_env.env import make_pomapf
19from pomapf_env.wrappers import MatrixObservationWrapper
20
21
def make_env(env_cfg: "Environment | None" = None):
    """Create a bare POMAPF environment from an ``Environment`` config.

    Args:
        env_cfg: Environment settings; only ``env_cfg.grid_config`` is read here.
            When omitted (or ``None``), a fresh default ``Environment()`` is built
            for this call.

    Returns:
        The environment produced by ``make_pomapf`` for that grid config. No
        observation wrappers are applied here.
    """
    # A signature default of ``Environment()`` is evaluated once at import time.
    # Every call would then share that single instance (and its grid_config), so build it per call.
    if env_cfg is None:
        env_cfg = Environment()
    return make_pomapf(grid_config=env_cfg.grid_config)
25
26
def create_pogema_env(full_env_name, cfg=None, env_config=None):
    """Build the fully wrapped POMAPF environment used for training.

    This is the factory passed to sample_factory's env registry in
    ``register_custom_components``.

    Args:
        full_env_name: Unused here; presumably part of the registry's factory signature.
        cfg: Flat run config. Despite the ``None`` default it is required: its
            ``full_config['environment']`` mapping is unpacked into ``Environment``.
        env_config: Unused here; presumably part of the registry's factory signature.

    Returns:
        ``make_env(...)`` wrapped in ``GridMemoryWrapper``, then ``MatrixObservationWrapper``.
    """
    env_settings: Environment = Environment(**cfg.full_config['environment'])
    base_env = make_env(env_settings)

    # A falsy grid-memory radius (None, 0) falls back to the agent's observation radius.
    memory_radius = env_settings.grid_memory_obs_radius or env_settings.grid_config.obs_radius
    with_memory = GridMemoryWrapper(base_env, obs_radius=memory_radius)
    return MatrixObservationWrapper(with_memory)
34
35
def register_custom_components():
    """Register the POMAPF env factory and the extra stats/summary hooks with sample_factory.

    The factory ``create_pogema_env`` is registered under the ``'POMAPF'`` prefix.
    Both hooks are appended to sample_factory's global hook lists only if they
    are not already present. A repeated call therefore does not make every
    custom summary get written several times.

    NOTE(review): re-registering the same prefix presumably just replaces the
    registry entry — confirm against sample_factory's ``EnvRegistry.register_env``.
    """
    global_env_registry().register_env(
        env_name_prefix='POMAPF',
        make_env_func=create_pogema_env,
    )

    # These are module-level lists imported from sample_factory; an unconditional
    # append would duplicate the hooks on each call.
    if pogema_extra_episodic_stats_processing not in EXTRA_EPISODIC_STATS_PROCESSING:
        EXTRA_EPISODIC_STATS_PROCESSING.append(pogema_extra_episodic_stats_processing)
    if pogema_extra_summaries not in EXTRA_PER_POLICY_SUMMARIES:
        EXTRA_PER_POLICY_SUMMARIES.append(pogema_extra_summaries)
44
45
def pogema_extra_episodic_stats_processing(policy_id, stat_key, stat_value, cfg):
    """No-op hook for per-episode stat processing.

    Registered in sample_factory's ``EXTRA_EPISODIC_STATS_PROCESSING`` list as a
    placeholder. Every argument is ignored and nothing is returned.
    """
    return None
48
49
# Stat keys that are never written by the custom summary hook below.
_SKIPPED_SUMMARY_KEYS = frozenset({'reward', 'len', 'true_reward', 'Done'})


def pogema_extra_summaries(policy_id, policy_avg_stats, env_steps, summary_writer, cfg):
    """Write the mean of every custom per-policy stat to ``summary_writer``.

    Args:
        policy_id: Index used to select this policy's values from each stat entry.
        policy_avg_stats: Mapping from stat name to a per-policy collection of recent
            values, indexed as ``policy_avg_stats[key][policy_id]``.
        env_steps: Step counter used as the x-axis of each written scalar.
        summary_writer: Object exposing ``add_scalar(tag, value, step)``.
        cfg: Unused; presumably part of the sample_factory hook signature.

    Keys in ``_SKIPPED_SUMMARY_KEYS`` and stats with no values yet for this
    policy are skipped.
    """
    for key in policy_avg_stats:
        if key in _SKIPPED_SUMMARY_KEYS:
            continue

        values = np.asarray(policy_avg_stats[key][policy_id])
        # Nothing collected yet for this policy. np.mean would emit a RuntimeWarning
        # and return NaN, which would be logged as a spurious data point.
        if values.size == 0:
            continue

        avg = np.mean(values)
        summary_writer.add_scalar(key, avg, env_steps)
        log.debug('%s-%s: %s', policy_id, key, round(float(avg), 3))
58
59
def validate_config(config):
    """Validate a raw config mapping and flatten it for sample_factory.

    Args:
        config: Mapping unpacked into ``Experiment``.

    Returns:
        ``(experiment, flat)`` where ``experiment`` is the validated ``Experiment``.
        ``flat`` is an ``argparse.Namespace`` holding every field of its ``async_ppo``,
        ``experiment_settings``, ``global_settings`` and ``evaluation`` sections.
        It also has ``full_config`` set to the whole experiment as a dict. A key
        present in more than one section makes the ``Namespace`` call raise ``TypeError``.
    """
    experiment = Experiment(**config)

    ppo_fields = experiment.async_ppo.dict()
    settings_fields = experiment.experiment_settings.dict()
    global_fields = experiment.global_settings.dict()
    evaluation_fields = experiment.evaluation.dict()

    flat = Namespace(
        **ppo_fields,
        **settings_fields,
        **global_fields,
        **evaluation_fields,
        full_config=experiment.dict(),
    )
    return experiment, flat
69
70
def select_free_dir_name(rood_dir, max_id=100000):
    """Return the first unused zero-padded folder name under ``rood_dir``.

    Candidates ``"0001"``, ``"0002"``, ... (at least four digits) are tried in
    order. The first one for which ``rood_dir / name`` does not exist is returned.
    Only the name is returned, and the folder is not created here, so concurrent
    callers could pick the same name.

    Args:
        rood_dir: Root directory to search (str or path-like).
        max_id: Number of candidates to try, i.e. ids ``1..max_id`` inclusive.

    Raises:
        KeyError: If all ``max_id`` candidates already exist.
    """
    root = Path(rood_dir)
    # range() excludes its end, hence +1: exactly ``max_id`` candidates are tried,
    # matching the count reported in the error message.
    for cnt in range(1, max_id + 1):
        free_folder = f"{cnt:04d}"
        if not (root / free_folder).exists():
            return free_folder
    raise KeyError(f"Can't select a folder in {max_id} attempts")
78
79
def _parse_bool_flag(value):
    """Convert a command-line string such as ``"true"``, ``"False"`` or ``"1"`` to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every non-empty
    string (including ``"False"``), so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    import argparse

    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """Train from a YAML config.

    Steps: parse CLI args, register the POMAPF components, load and validate the
    config, pick a free experiments folder if none is set, optionally start wandb,
    then run training.

    Returns:
        The value returned by ``run_algorithm``; it is used as the process exit status.

    Raises:
        ValueError: If the config file does not contain a top-level mapping.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(description='Process training config.')

    parser.add_argument('--config_path', type=str, action="store", default='configs/train-debug.yaml',
                        help='path to yaml file with single run configuration', required=False)

    # A bare ``--wandb_thread_mode`` enables the option (nargs='?' + const=True).
    # The older ``--wandb_thread_mode True/False`` form is still accepted and now parsed correctly.
    parser.add_argument('--wandb_thread_mode', type=_parse_bool_flag, nargs='?', const=True, default=False,
                        help='Run wandb in thread mode. Usefull for some setups.', required=False)

    params = parser.parse_args()

    register_custom_components()
    if params.config_path is None:
        raise ValueError("You should specify --config_path argument!")
    with open(params.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file; fail with a clear message instead of
    # an obscure ``**`` unpacking error inside validate_config.
    if not isinstance(config, dict):
        raise ValueError(f"Expected a mapping at the top level of {params.config_path}, "
                         f"got {type(config).__name__}")

    exp, flat_config = validate_config(config)
    if exp.global_settings.experiments_root is None:
        exp.global_settings.experiments_root = select_free_dir_name(exp.global_settings.train_dir)
        # Re-validate so flat_config (built from exp) picks up the chosen experiments_root.
        exp, flat_config = validate_config(exp.dict())
    log.debug(exp.global_settings.experiments_root)

    if exp.global_settings.use_wandb:
        if params.wandb_thread_mode:
            # Set before wandb.init(), which presumably reads it at start-up.
            os.environ["WANDB_START_METHOD"] = "thread"
        wandb.init(project=exp.name, config=exp.dict(), save_code=False, sync_tensorboard=True, anonymous="allow")

    return run_algorithm(flat_config)


if __name__ == '__main__':
    sys.exit(main())
118