learn-to-follow

Форк
0
/
example.py 
67 строк · 2.5 Кб
1
import argparse
2

3
from env.create_env import create_env_base
4
from env.custom_maps import MAPS_REGISTRY
5
from utils.eval_utils import run_episode
6
from follower.training_config import EnvironmentMazes
7
from follower.inference import FollowerInferenceConfig, FollowerInference
8
from follower.preprocessing import follower_preprocessor
9
from follower_cpp.inference import FollowerConfigCPP, FollowerInferenceCPP
10
from follower_cpp.preprocessing import follower_cpp_preprocessor
11

12

13
def create_custom_env(cfg):
14
    env_cfg = EnvironmentMazes(with_animation=cfg.animation)
15
    env_cfg.grid_config.num_agents = cfg.num_agents
16
    env_cfg.grid_config.map_name = cfg.map_name
17
    env_cfg.grid_config.seed = cfg.seed
18
    env_cfg.grid_config.max_episode_steps = cfg.max_episode_steps
19
    return create_env_base(env_cfg)
20

21

22
def run_follower(env):
23
    follower_cfg = FollowerInferenceConfig()
24
    algo = FollowerInference(follower_cfg)
25

26
    env = follower_preprocessor(env, follower_cfg)
27

28
    return run_episode(env, algo)
29

30

31
def run_follower_cpp(env):
32
    follower_cfg = FollowerConfigCPP(path_to_weights='model/follower-lite', num_threads=6)
33
    algo = FollowerInferenceCPP(follower_cfg)
34

35
    env = follower_cpp_preprocessor(env, follower_cfg)
36

37
    return run_episode(env, algo)
38

39

40
def main():
41
    parser = argparse.ArgumentParser(description='Follower Inference Script')
42
    parser.add_argument('--animation', action='store_false', help='Enable animation (default: %(default)s)')
43
    parser.add_argument('--num_agents', type=int, default=128, help='Number of agents (default: %(default)d)')
44
    parser.add_argument('--seed', type=int, default=0, help='Random seed (default: %(default)d)')
45
    parser.add_argument('--map_name', type=str, default='wfi_warehouse', help='Map name (default: %(default)s)')
46
    parser.add_argument('--max_episode_steps', type=int, default=256,
47
                        help='Maximum episode steps (default: %(default)d)')
48
    parser.add_argument('--show_map_names', action='store_true', help='Shows names of all available maps')
49

50
    parser.add_argument('--algorithm', type=str, choices=['Follower', 'FollowerLite'], default='Follower',
51
                        help='Algorithm to use: "Follower" or "FollowerLite" (default: "Follower")')
52

53
    args = parser.parse_args()
54

55
    if args.show_map_names:
56
        for map_ in MAPS_REGISTRY:
57
            print(map_)
58
        return
59

60
    if args.algorithm == 'FollowerLite':
61
        print(run_follower_cpp(create_custom_env(args)))
62
    else:  # Default to 'Follower'
63
        print(run_follower(create_custom_env(args)))
64

65

66
if __name__ == '__main__':
67
    main()
68

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.