learn-to-follow
/
example.py
67 lines · 2.5 KB
1import argparse
2
3from env.create_env import create_env_base
4from env.custom_maps import MAPS_REGISTRY
5from utils.eval_utils import run_episode
6from follower.training_config import EnvironmentMazes
7from follower.inference import FollowerInferenceConfig, FollowerInference
8from follower.preprocessing import follower_preprocessor
9from follower_cpp.inference import FollowerConfigCPP, FollowerInferenceCPP
10from follower_cpp.preprocessing import follower_cpp_preprocessor
11
12
def create_custom_env(cfg):
    """Build a maze environment configured from parsed command-line options.

    An ``EnvironmentMazes`` config is created with animation toggled by
    ``cfg.animation``. Then ``num_agents``, ``map_name``, ``seed`` and
    ``max_episode_steps`` are copied from *cfg* onto its ``grid_config``.
    Whatever ``create_env_base`` builds from that config is returned.

    Args:
        cfg: Object exposing ``animation``, ``num_agents``, ``map_name``,
            ``seed`` and ``max_episode_steps`` attributes (in this script, the
            namespace returned by ``parser.parse_args()`` in ``main``).
    """
    environment_config = EnvironmentMazes(with_animation=cfg.animation)
    grid = environment_config.grid_config
    # Copy the grid-level options one by one, in the same order as before.
    for option in ('num_agents', 'map_name', 'seed', 'max_episode_steps'):
        setattr(grid, option, getattr(cfg, option))
    return create_env_base(environment_config)
20
21
def run_follower(env):
    """Run one episode on *env* with the Python ``FollowerInference`` policy.

    A single default ``FollowerInferenceConfig`` is used both to build the
    policy and to wrap *env* with ``follower_preprocessor``. The value that
    ``run_episode`` returns for the wrapped environment is returned unchanged.
    """
    config = FollowerInferenceConfig()
    policy = FollowerInference(config)
    wrapped_env = follower_preprocessor(env, config)
    return run_episode(wrapped_env, policy)
29
30
def run_follower_cpp(env, *, path_to_weights='model/follower-lite', num_threads=6):
    """Run one episode on *env* with the C++ ``FollowerInferenceCPP`` policy (FollowerLite).

    Args:
        env: Environment to evaluate on. It is wrapped with
            ``follower_cpp_preprocessor`` before the episode is run.
        path_to_weights: Forwarded to ``FollowerConfigCPP``. Defaults to
            ``'model/follower-lite'``, the value previously hard-coded here.
            NOTE(review): a relative path is presumably resolved against the
            current working directory; confirm in ``FollowerConfigCPP``.
        num_threads: Forwarded to ``FollowerConfigCPP``. Defaults to 6, the
            value previously hard-coded here.

    Returns:
        Whatever ``run_episode`` returns for the wrapped environment.
    """
    # Both knobs were constants; they are now keyword-only parameters with the
    # same defaults, so existing ``run_follower_cpp(env)`` calls are unchanged.
    follower_cfg = FollowerConfigCPP(path_to_weights=path_to_weights, num_threads=num_threads)
    algo = FollowerInferenceCPP(follower_cfg)

    # The same config object drives both the policy and the observation wrapper.
    env = follower_cpp_preprocessor(env, follower_cfg)

    return run_episode(env, algo)
38
39
def main():
    """Parse command-line options, build an environment and run one episode on it.

    With ``--show_map_names``, every entry obtained by iterating
    ``MAPS_REGISTRY`` is printed (presumably the registered map names) and
    nothing is run. Otherwise the environment is built by
    ``create_custom_env`` from the parsed options. It is then evaluated with
    ``run_follower`` ("Follower") or ``run_follower_cpp`` ("FollowerLite"),
    and the value returned for the episode is printed.
    """
    parser = argparse.ArgumentParser(description='Follower Inference Script')
    # Animation is ON by default (store_false => default True), and this flag
    # turns it OFF. The old help text said "Enable animation", the opposite of
    # what passing the flag does. ``--no_animation`` names the behaviour
    # honestly; ``--animation`` is kept as an alias so existing command lines
    # still work exactly as before.
    parser.add_argument('--no_animation', '--animation', dest='animation', action='store_false',
                        help='Disable animation (enabled by default)')
    parser.add_argument('--num_agents', type=int, default=128, help='Number of agents (default: %(default)d)')
    parser.add_argument('--seed', type=int, default=0, help='Random seed (default: %(default)d)')
    parser.add_argument('--map_name', type=str, default='wfi_warehouse', help='Map name (default: %(default)s)')
    parser.add_argument('--max_episode_steps', type=int, default=256,
                        help='Maximum episode steps (default: %(default)d)')
    parser.add_argument('--show_map_names', action='store_true', help='Shows names of all available maps')

    parser.add_argument('--algorithm', type=str, choices=['Follower', 'FollowerLite'], default='Follower',
                        help='Algorithm to use: "Follower" or "FollowerLite" (default: "Follower")')

    args = parser.parse_args()

    if args.show_map_names:
        for map_ in MAPS_REGISTRY:
            print(map_)
        return

    # ``choices`` above guarantees that args.algorithm is one of these keys.
    runners = {'Follower': run_follower, 'FollowerLite': run_follower_cpp}
    run = runners[args.algorithm]
    print(run(create_custom_env(args)))
64
65
if __name__ == '__main__':
    # Run the CLI only when this file is executed directly, never on import.
    main()